Using Machine Learning to Identify The P300 Wave — Brain-Computer Interfaces

Original article can be found here (source): Artificial Intelligence on Medium

How does the brain really make decisions? What is consciousness? Why do we need sleep? 🤔

Image from https://www.wired.com/2017/03/elon-musks-neural-lace-really-look-like/

These are all super complicated questions, and finding out what actually goes on in our heads to make us who we are will require a solution that is no less complicated. With neural networks, we can begin to chip away at the mystery of the brain by classifying and identifying patterns in complicated data.

In this article, I will describe how I use Machine Learning to identify a P300 wave, the brainwave that indicates the presence of a rare or unexpected event.

Today, the best way to gather data on the VERY complex human brain non-invasively is with voltage-reading sensors. The data gathered requires complex analysis to tell apart the different brainwaves that are all present at the same time.

EEGs are often used to collect brainwave data. This image is from https://www.verywellhealth.com/what-is-an-eeg-test-and-what-is-it-used-for-3014879

To make it even harder, though, everyone’s brain is different. This means the interface has to be tailored to each individual before it can accurately read and analyze that person's brainwaves.

In my neural network, I analyze the P300 brainwave. This wave corresponds to decision making in the brain, and is mostly recorded over the parietal lobe. An example of a decision we can test for in a lab is whether the patient has decided to press a button with their right hand or not, and then we analyze the changes in the P300 wave.

How the data was collected:

A subject was shown a set of images (checkerboard patterns and smiley faces) and told to press a key with the right index finger as quickly as they could when they saw the smiley face.

Basically, I use my neural network to detect when the smiley face was shown and the key was pressed!

I used the MNE dataset recorded in this experiment. In this dataset, each sample represents one trial and contains 226 data points. Below is a plot of two such samples; the channel is shown at the top right of the graph.

Image generated from the neural network.

So what is the ‘P300’?

The P300 (also called P3) is a positive deflection, or a positive spike, recorded in our brainwaves around 300 milliseconds after we observe an anticipated or surprising stimulus. The P300 is often used to build a speller in BCIs: a patient can focus on the letters displayed on the screen in order to type them. An example can be seen here: https://www.youtube.com/watch?v=XIr2cRKFolY
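To make the "around 300 milliseconds" figure concrete, here is a minimal sketch that builds a synthetic, noise-free waveform with a positive bump near 0.3 s and then looks for the largest value in a window after the stimulus. The waveform, the 250 to 500 ms search window, and the function name `find_p300_peak` are all assumptions made for illustration; none of them come from the MNE dataset or from my project code.

```python
import numpy as np

def find_p300_peak(times, signal, window=(0.25, 0.5)):
    """Return (latency_s, amplitude) of the largest value inside the window."""
    mask = (times >= window[0]) & (times <= window[1])
    idx = np.argmax(signal[mask])
    return times[mask][idx], signal[mask][idx]

# Synthetic epoch: 0.5 s before the stimulus to 1 s after, 226 points.
times = np.linspace(-0.5, 1.0, 226)
# A Gaussian bump centred at 0.3 s stands in for the P300 deflection.
signal = 5.0 * np.exp(-((times - 0.3) ** 2) / (2 * 0.05 ** 2))

latency, amplitude = find_p300_peak(times, signal)
print(f"peak at {latency:.3f} s with amplitude {amplitude:.2f}")
```

Real EEG is far noisier than this, which is part of why a simple peak search like this one is not enough on its own.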

Screenshot from Cognionics Dry EEG P300 Speller Demo

Why is a neural network necessary for detecting the P300?

Summary of what a neural network does:

A neuron is composed of dendrites, which act as the inputs, and axons, which act as the outputs. When a neuron receives a signal, if the signal passes a firing threshold, it moves into the neuron’s axon and travels to other neurons down the line.

Image of neuron from https://training.seer.cancer.gov/brain/tumors/anatomy/neurons.html

Some neurons have stronger connections than others: they have a “weight” of connection. Neural networks function similarly, where each artificial neuron uses weights and an “activation algorithm” (an activation function) to calculate its output.
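As a rough illustration of weights and activation, here is a sketch of a single artificial neuron: it multiplies each input by a weight, adds the results up, and squashes the total with a sigmoid. The function name `artificial_neuron` and the example numbers are made up for this illustration.

```python
import math

def artificial_neuron(inputs, weights, bias):
    """Weighted sum of the inputs passed through a sigmoid activation."""
    weighted_sum = sum(x * w for x, w in zip(inputs, weights)) + bias
    return 1.0 / (1.0 + math.exp(-weighted_sum))

inputs = [0.5, -1.0, 2.0]
# Same inputs, but the last connection is much stronger in the first call.
strong = artificial_neuron(inputs, weights=[0.1, 0.1, 2.0], bias=0.0)
weak = artificial_neuron(inputs, weights=[0.1, 0.1, 0.1], bias=0.0)
print(strong, weak)
```

With the stronger connection, the same inputs push the neuron's output much closer to 1.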

A neural network does not have an equation it can rely on, so for it to calculate an accurate output we must feed it data and allow it to train, creating a curve that lets it output an accurate decision.

What types of problems are neural networks effective in:

Neural networks are the best solution for problems where there is no set equation, the input varies, and there may be unknown variables at play. This is exactly why a neural network is needed in my case, because:

  1. I do not have the equation to analyze for P300 waves
  2. The P300 wave varies from brain to brain
  3. There may be non-linear noise (noise that we can’t simply subtract from the data)

So I need to use a neural network to figure out the best way to analyze for this wave!

How is the data represented in the network:

Each of the 226 data points is a different dimension of the data. You may be wondering: why not two-dimensional, like it is on this graph?

Image generated from the neural network.

Each sample of a brainwave has 226 data points, all with a different (Time(s), nV) value, representing it. If we use a 226-dimensional graph to plot a single point representing one brainwave, we can draw a 226-dimensional curve that separates the brainwaves accurately. To picture this, all of the 226 values that make up the brainwave influence its position on the plot, and therefore determine which side of the curve it will be on. The curve decides whether the brainwave is a P300!
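Here is a small sketch of what "one point per brainwave" looks like in code. It uses random numbers in place of real brainwaves and a flat boundary (the simplest possible "curve") with random weights, so the values mean nothing; the point is only the shapes involved. The variable names are my own choices for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points = 226  # values per brainwave sample

# Two made-up samples stacked into a (samples, dimensions) array.
samples = rng.normal(size=(2, n_points))

# A flat boundary in 226 dimensions: one weight per data point plus an offset.
weights = rng.normal(size=n_points)
offset = 0.0

# The sign of the score says which side of the boundary each sample is on.
scores = samples @ weights + offset
sides = scores > 0
print(samples.shape, sides)
```

A trained network replaces the random weights with learned ones and bends the flat boundary into a curve.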

While it is hard to picture something like this, it is even harder to solve a problem of this magnitude by hand using mathematical tools for high-dimensional problems.

Building the Machine Learning algorithm:

In my project, I use:

  • MNE — EEG dataset
  • NumPy and SciPy — Scientific Computing libraries
  • Scikit-Learn and PyTorch — Machine Learning libraries

The neural network has distinct components, and is made up of sets of linear and activation nodes. Our input data will go through the network, until it finally passes through a sigmoid function. The final value is scaled to be a number between 0 and 1. The closer the number is to 1, the more confident the network is that the data represents a P300 wave.

The network code is shown below. It is composed of:

  • 1 input layer
  • 3 hidden layers
  • 1 output layer
import torch
from torch import nn

# eeg_sample_length, hidden1, hidden2, hidden3 and number_of_classes
# are defined elsewhere in the project
eeg_model = nn.Sequential()

# input layer
eeg_model.add_module('Input Linear', nn.Linear(eeg_sample_length, hidden1))
eeg_model.add_module('Input Activation', nn.CELU())

# hidden layers
eeg_model.add_module('Hidden Linear 1', nn.Linear(hidden1, hidden2))
eeg_model.add_module('Hidden Activation', nn.ReLU())
eeg_model.add_module('Hidden Linear 2', nn.Linear(hidden2, hidden3))
eeg_model.add_module('Hidden Activation2', nn.ReLU())
eeg_model.add_module('Hidden Linear 3', nn.Linear(hidden3, 10))
eeg_model.add_module('Hidden Activation3', nn.ReLU())

# output layer: the sigmoid squashes the result to a value between 0 and 1
eeg_model.add_module('Output Linear', nn.Linear(10, number_of_classes))
eeg_model.add_module('Output Activation', nn.Sigmoid())
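
To check what comes out of a network with this shape, here is a self-contained sketch that pushes four random "brainwaves" through the same stack of layers. The hidden sizes (64, 32, 16), the single output class, and the name `demo_model` are placeholders I picked only so the example runs; they are not the values used in my project.

```python
import torch
from torch import nn

# Placeholder sizes, chosen only so the example runs.
eeg_sample_length = 226
hidden1, hidden2, hidden3 = 64, 32, 16
number_of_classes = 1

demo_model = nn.Sequential(
    nn.Linear(eeg_sample_length, hidden1), nn.CELU(),
    nn.Linear(hidden1, hidden2), nn.ReLU(),
    nn.Linear(hidden2, hidden3), nn.ReLU(),
    nn.Linear(hidden3, 10), nn.ReLU(),
    nn.Linear(10, number_of_classes), nn.Sigmoid(),
)

batch = torch.randn(4, eeg_sample_length)  # four random "brainwaves"
with torch.no_grad():                      # no training here, just a forward pass
    confidence = demo_model(batch)
print(confidence.shape)
```

Each row of `confidence` is one number between 0 and 1, the network's confidence for that sample. Before training, these numbers are meaningless.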

Then I set up the loss and training functions to tweak the network until it can accurately classify the brainwave data. The lower the accuracy of the network, the bigger the loss. So, we want to see a very small or zero loss when the network is done training.
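To show how the loss behaves, here is a tiny worked example of mean squared error in plain Python. The helper `mean_squared_error` and the prediction values are made up for illustration; in the project itself the loss comes from PyTorch.

```python
def mean_squared_error(predictions, targets):
    """Average of the squared differences between predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

targets = [1.0, 0.0, 1.0]  # 1 = P300, 0 = not P300

# An undecided network that answers 0.5 for everything.
poor_loss = mean_squared_error([0.5, 0.5, 0.5], targets)
# A network whose answers are close to the true labels.
good_loss = mean_squared_error([0.9, 0.1, 0.9], targets)
print(poor_loss, good_loss)
```

The closer the predictions get to the true labels, the closer the loss gets to zero.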

Below, I used PyTorch to make a loss function, update the network, and calculate loss per iteration.

# definition of loss function
loss_func = torch.nn.MSELoss()

# training procedure definition
# (uses the global eeg_model and optimizer)
def train_network(train_data, actual_class, iterations):
    # loss recorded at every iteration
    loss_data = []

    for i in range(iterations):
        # start with classification
        classification = eeg_model(train_data)

        # then find out how wrong the algorithm was
        loss = loss_func(classification, actual_class)
        # store a plain float so the computation graph is not kept in memory
        loss_data.append(loss.item())

        # zero out optimizer gradients per iteration
        optimizer.zero_grad()

        # optimize the network to do better the next iteration
        loss.backward()
        optimizer.step()

    return loss_data
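
To see this kind of loop run from start to finish, here is a self-contained sketch on toy data. Everything in it is an assumption made for the demo: the data is random, the "positive" samples are simply shifted up by 1, and `toy_model` is much smaller than the real network.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy data: 20 samples of 226 points; the first 10 are "positives".
features = torch.randn(20, 226)
labels = torch.zeros(20, 1)
labels[:10] = 1.0
features[:10] += 1.0  # shift positives so the two classes are separable

toy_model = nn.Sequential(
    nn.Linear(226, 16), nn.ReLU(),
    nn.Linear(16, 1), nn.Sigmoid(),
)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

loss_history = []
for _ in range(100):
    prediction = toy_model(features)       # classify
    loss = loss_func(prediction, labels)   # measure how wrong it was
    loss_history.append(loss.item())
    optimizer.zero_grad()                  # clear old gradients
    loss.backward()                        # compute new gradients
    optimizer.step()                       # update the weights

print(loss_history[0], loss_history[-1])
```

The loss at the end should be clearly smaller than at the start, which is the same trend a loss graph shows.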

Train and Test The Network On Sample Data

When you train to do something, such as learning how to add, you train on one set of questions. The test your teacher gives you to see that you have learned how to add will contain questions with addition, but not the exact same ones. This determines whether you really learned the concept of addition rather than just memorized all the answers to your practice questions. In my network, I will first train it on one set of brainwave data.
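In code, keeping "practice questions" and "test questions" apart usually means shuffling the samples and holding some of them back. Here is a minimal sketch with NumPy; the helper `split_train_test`, the 25% test fraction, and the placeholder data are assumptions for illustration rather than the split I used.

```python
import numpy as np

def split_train_test(samples, labels, test_fraction=0.25, seed=0):
    """Shuffle the samples, then hold out a fraction of them for testing."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(samples))
    n_test = int(round(len(samples) * test_fraction))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return samples[train_idx], labels[train_idx], samples[test_idx], labels[test_idx]

# Placeholder data: 24 samples of 226 points, half labelled 1 and half 0.
samples = np.arange(24 * 226, dtype=float).reshape(24, 226)
labels = np.array([1] * 12 + [0] * 12)

x_train, y_train, x_test, y_test = split_train_test(samples, labels)
print(x_train.shape, x_test.shape)
```

The network only ever sees `x_train` during training, so its score on `x_test` shows whether it learned the pattern or just memorized the answers.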

Image from https://www.wms-partners.com/math-over-mood/

First, I am going to test if the network is working correctly with easy-to-classify data. I graph the dataset to have a visual representation of what is going on. The graph can be seen below. The “good” sample average is represented by the green line: it has values around 0.5 and a larger amplitude than the two “bad” classes. The two “bad” classes have a smaller amplitude and are centred around 0.25 and 0.75.

Then, I will design the network around this sample data. If it can classify the sample data correctly, then the likelihood that it classifies the more complex dataset correctly increases. I then run the network on the remaining sample data to see if it learned correctly, and it does!

Retrieve the Data from the MNE Dataset

In order to access the database, I set the path to the specific dataset I am using: the database containing brainwaves filtered from 0 to 40 Hz.

data_path = mne.datasets.sample.data_path()
...
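# str() keeps the concatenation working if data_path is a pathlib.Path
raw_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
# load the recording into memory, then apply the default EEG reference
rawdata = mne.io.read_raw_fif(raw_fname, preload=True)
rawdata.set_eeg_reference()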

This data has not been sliced around a specific event (it is unprocessed). It is a collection of brainwave samples collected from numerous EEG channels. In the code, I load the dataset and store it in rawdata. The dataset also contains magnetoencephalography data, represented by the parameter meg. In this project I only want to analyze EEG data, so meg is set to False.

# keep only the EEG and EOG channels
rawdata = rawdata.pick(picks=["eeg", "eog"])
picks_eeg_only = mne.pick_types(rawdata.info,
                                eeg=True,
                                eog=True,
                                meg=False,
                                exclude='bads')

Once I load the data, I slice it (epoch it) from 0.5 seconds before the smiley image was presented to 1 second after. It is important to note that this dataset only has 12 examples of the P300 wave, and it would have been better to train the neural network on at least 100 examples.
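The idea behind epoching can be sketched with plain NumPy, without MNE. The code below cuts a window from 0.5 s before to 1 s after each event out of a fake single-channel recording. The 150 Hz sampling rate, the event positions, and the helper `slice_epochs` are assumptions for this illustration; in the project, MNE does the slicing on the real recording.

```python
import numpy as np

def slice_epochs(signal, event_samples, sfreq, tmin=-0.5, tmax=1.0):
    """Cut one fixed-length window out of a 1-D signal around each event."""
    start = int(round(tmin * sfreq))  # negative: samples before the event
    stop = int(round(tmax * sfreq))   # samples after the event
    epochs = [
        signal[e + start : e + stop + 1]  # both end points are included
        for e in event_samples
        if e + start >= 0 and e + stop < len(signal)  # skip windows that run off the ends
    ]
    return np.array(epochs)

sfreq = 150.0                                         # assumed sampling rate in Hz
signal = np.random.default_rng(0).normal(size=3000)   # fake recording
event_samples = [500, 1200, 2100]                     # made-up stimulus onsets

epochs = slice_epochs(signal, event_samples, sfreq)
print(epochs.shape)
```

Under these assumptions, each window is 75 + 150 + 1 = 226 points long, the same sample length used throughout this article.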

Then, I will take a similar amount of data that does not contain the P300 wave for comparison, and mix it with the rest of the data.
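Mixing the two kinds of samples can be sketched like this. The arrays are random placeholders standing in for the sliced brainwaves, and the names `p300_epochs` and `other_epochs` are hypothetical. What matters is that the samples and their labels are shuffled with the same ordering so they stay paired.

```python
import numpy as np

rng = np.random.default_rng(0)
p300_epochs = rng.normal(loc=1.0, size=(12, 226))   # placeholder positives
other_epochs = rng.normal(loc=0.0, size=(12, 226))  # placeholder negatives

# Stack both groups and build matching labels: 1 = P300, 0 = not P300.
data = np.concatenate([p300_epochs, other_epochs])
labels = np.concatenate([np.ones(12), np.zeros(12)])

# Shuffle samples and labels with the same permutation so they stay paired.
order = rng.permutation(len(data))
data, labels = data[order], labels[order]
print(data.shape, labels[:6])
```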

Classification With Neural Network

To train the network using the data I just sliced, check out the code I use:

# load the network in its default state
eeg_model = torch.load("/home/eeg_model_default_state")
# learning function definition
optimizer = torch.optim.Adam(eeg_model.parameters(), lr = learning_rate)
#use the training procedure on sample data
print("Below is the loss graph for dataset training session")
train_network(training_data, labels, iterations = 50)

The loss graph below shows us that the network gets better at identifying the P300 wave!

Testing The Neural Network

Here are some of the classification results:

P300 Positive Classification 1: 100.00%
P300 Positive Classification 2: 99.94%
P300 Positive Classification 3: 100.00%
P300 Negative Classification 1: 99.92%
P300 Negative Classification 2: 0.04%
P300 Negative Classification 3: 0.00%

Hmmmm. One of the negative samples shown here was misclassified! This is most likely because we did not have enough data to train on, which means that the network was not able to see the difference between some of the Negative and Positive samples. However, for a network trained on so FEW samples, it did very well! This suggests that with more data, the network is likely to be able to identify the P300 wave accurately!
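
For reference, here is a rough sketch of how sigmoid outputs can be turned into percentages like the ones listed above, and how a yes/no decision can be read off them. The 0.5 cut-off and the helper `summarize` are my assumptions for this example; the outputs are the six listed values written as fractions.

```python
def summarize(outputs, expected_positive, threshold=0.5):
    """Format each sigmoid output as a percentage and score the yes/no calls."""
    lines, correct = [], 0
    for i, (value, is_positive) in enumerate(zip(outputs, expected_positive), start=1):
        predicted_positive = value >= threshold
        correct += predicted_positive == is_positive
        lines.append(f"Sample {i}: {value * 100:.2f}%")
    return lines, correct / len(outputs)

# The six confidence values listed above, as fractions between 0 and 1.
outputs = [1.0, 0.9994, 1.0, 0.9992, 0.0004, 0.0]
expected_positive = [True, True, True, False, False, False]

lines, accuracy = summarize(outputs, expected_positive)
print("\n".join(lines))
print(f"accuracy: {accuracy:.2f}")
```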