Select Git revision
artefact_force.py 8.66 KiB
import pywt
import concurrent.futures
import matplotlib.pyplot as plt
from mne.io import read_raw_edf
from preprocess.artefact_1_AMI import *
from preprocess.artefact_2_Spiking import *
from preprocess.artefact_3_Kurtosis import *
from preprocess.artefact_4_5_PSD import *
from preprocess.artefact_6_Projection_STD import *
from preprocess.artefact_7_Topographic_distribution import *
from preprocess.artefact_8_Amplitude_thresholding import *
from preprocess.artefact_Spike_zone_thresholding import *
from preprocess.artefact_PSD import *
from sklearn.decomposition import FastICA
""" This implementation is based on FORCe: Fully Online and Automated Artifact
Removal for Brain-Computer Interfacing by Ian Daly, Reinhold Scherer, Martin Billinger, and Gernot Müller-Putz."""
""" Constants: """
# 4.1 AMI
LAG_OFFSET = 60
THRESHOLD_MAX = 2.0
THRESHOLD_MIN = 1.0
# 4.2 Spiking
THRESHOLD_COEFF = 0.25
# 4.4-5 PSD + Gamma
DISTANCE = 3.5
THRESHOLD_GAMMA = 1.7
# 4.8 Large amplitude removing
THRESHOLD_AMPLITUDE = 100
PEAK_DIFFERENCE = 60
# 5 Spike-zone thresholding
DC_checkVal = 0.2
DC_adjustVal = 0.07
AC_checkVal = 0.7
AC_adjustVal = 0.8
REMOVING_THRESHOLD = 3
def FORCe(info, original_data):
nCh = len(info['chs'])
data = np.multiply(original_data, 1e6)
""" Step 1) - 4) """
waveletname = 'sym4'
coeffitiantsLevel1 = []
coeffitiantsLevel2 = []
cA2 = []
cD1 = []
cD2 = []
marked_ICs = np.zeros(nCh)
"""
# ''cAx'' is he array of approximation coefficient (low pass filters) and
# ''cDx'' is the list of detail coefficients (high pass filters).
# 2 level decomposition!
"""
for i in range(np.size(data, 0)):
actual_coeff = pywt.wavedec(data[i][:], waveletname, level=1)
cA1_actual, cD1_actual = actual_coeff
coeffitiantsLevel1.append(actual_coeff)
cD1.append(cD1_actual)
D1 = np.array(cD1)
for i in range(np.size(data, 0)):
actual_coeff = pywt.wavedec(data[i][:], waveletname, level=2)
cA2_actual, cD2_actual, cD1_actual = actual_coeff
coeffitiantsLevel2.append(actual_coeff)
cA2.append(cA2_actual)
cD2.append(cD2_actual)
A2 = np.array(cA2)
D2 = np.array(cD2)
"""
# S: 2D array containing estimated source signals
# A: 2D array containing mixing matrix, i.e. A.dot(S) = X
"""
ica = FastICA(n_components=len(info['chs']), algorithm='parallel')
S = ica.fit(A2.T).transform(A2.T)
A = ica.mixing_
assert np.allclose(A2.T, np.dot(S, A.T) + ica.mean_)
assert A.shape[0] == A.shape[1]
ICs_projections = np.zeros((np.size(S, 0), np.size(S, 1), np.size(S, 1)))
for i in range(np.size(S, 1)):
actual_S = np.copy(S)
for j in range(np.size(S, 1)):
if (j != i):
actual_S[:, j] = 0
actual_projection = np.dot(actual_S, A.T)
ICs_projections[:, :, i] = actual_projection
with concurrent.futures.ThreadPoolExecutor() as executor:
toRemoveICs1 = executor.submit(Auto_Mutual_Information, ICs_projections, LAG_OFFSET, THRESHOLD_MAX, THRESHOLD_MIN)
toRemoveICs2 = executor.submit(Spiking_activity, S, THRESHOLD_COEFF)
toRemoveICs3 = executor.submit(Kurtosis, ICs_projections)
toRemoveICs4 = executor.submit(PSD, ICs_projections, info['sfreq'], DISTANCE, THRESHOLD_GAMMA)
toRemoveICs5 = executor.submit(Projection_STD, ICs_projections)
toRemoveICs6 = executor.submit(Topographic_distribution, ICs_projections, info)
toRemoveICs7 = executor.submit(Amplitude_thresholding, ICs_projections, THRESHOLD_AMPLITUDE, PEAK_DIFFERENCE)
for i in range(np.size(marked_ICs)):
marked_ICs[i] = toRemoveICs1.result()[i] + toRemoveICs2.result()[i] + toRemoveICs3.result()[i] + toRemoveICs4.result()[i] + toRemoveICs5.result()[i] + toRemoveICs6.result()[i] + toRemoveICs7.result()[i]
# Without paralellisation:
""" Step 4).1: Auto-Mutal Informatin """
# toRemoveICs1 = Auto_Mutual_Information(ICs_projections, LAG_OFFSET, THRESHOLD_MAX, THRESHOLD_MIN, marked_ICs)
# print('Marked to remove by AMI: ', toRemoveICs1)
""" Step 4).2: Spiking activity """
# toRemoveICs2a, toRemoveICs2b = Spiking_activity(S, THRESHOLD_COEFF, marked_ICs)
# print('Marked to remove by Spiking-activity: ', toRemoveICs2a)
# print('Marked to remove by Spike zone coefficients: ', toRemoveICs2b)
""" 4).3 Kurtosis """
# toRemoveICs3 = Kurtosis(ICs_projections, marked_ICs)
# print('Marked to remove by Kurtosis : ', toRemoveICs3)
""" 4).4-5 Check PSDs + PSD of gamma frequency """
# toRemoveICs4, toRemoveICs5 = PSD(ICs_projections, info['sfreq'], DISTANCE, THRESHOLD_GAMMA, marked_ICs)
# print('Marked to remove by PSD & 1/F distribution: ', toRemoveICs4)
# print('Marked to remove by Gamma frequency: ', toRemoveICs5)
""" 4).6 Check stds of projections of ICs. """
# toRemoveICs6 = Projection_STD(ICs_projections, marked_ICs)
# print('Marked to remove by Std: ', toRemoveICs6)
""" 4).7 Topographic distribution of standard deviations """
# toRemoveICs7 = Topographic_distribution(ICs_projections, info, marked_ICs)
# print('Marked to remove by topographic distribution: ', toRemoveICs7)
""" 4).8 Remove ICs with large amplitudes """
# toRemoveICs8a, toRemoveICs8b = Amplitude_thresholding(ICs_projections, THRESHOLD_AMPLITUDE, PEAK_DIFFERENCE,
# marked_ICs)
# print('Marked to remove by Large amplitudes A: ', toRemoveICs8a)
# print('Marked to remove by Large amplitudes B: ', toRemoveICs8b)
for i in range(np.size(marked_ICs)):
if (marked_ICs[i] > REMOVING_THRESHOLD):
S[:, i] = 0
clean_A2 = (ica.inverse_transform(S)).T
""" 5) Spike Zone Thresholding """
newD1 = Thresholding(D1, DC_checkVal, DC_adjustVal)
newA2 = Thresholding(clean_A2, AC_checkVal, AC_adjustVal)
newD2 = Thresholding(D2, DC_checkVal, DC_adjustVal)
print("Number of channel markings: ", marked_ICs)
pre_clea_data = []
for i in range(np.size(D2, 0)):
coeff = [newA2[i], newD2[i], newD1[i]]
actual_reconstuction = pywt.waverec(coeff, waveletname)
pre_clea_data.append(actual_reconstuction)
clean_data = np.array(pre_clea_data)
return clean_data
class ArtefactFilter:
def offline_filter(self, epochs):
"""Offline Faster algorithm
Filters the input epochs, and saves the parameters (such as ICA weights),
for the possibility of online filtering
Parameters
----------
epochs : mne.Epochs
The epochs to analyze
Returns
-------
mne.Epochs
The filtered epoch
"""
for i in range(len(epochs)):
epochs._data[i, ...] = FORCe(epochs.info, epochs[i].get_data()[0])[:,:-1]
return epochs
if __name__ == '__main__':
file_name = 'F:/ÖNLAB/Databases/physionet.org/physiobank/database/eegmmidb/S001/S001R03.edf'
raw = read_raw_edf(file_name)
raw_croped = raw.crop(tmax=60).load_data()
raw_croped.resample(500)
data = raw_croped.get_data()
sampling_freq = raw.info['sfreq']
chanel_number = np.size(raw.info['ch_names'])
""" The following parameters must always be set manually! """
windowLengthCoeff = 1
windowLength = int(windowLengthCoeff * sampling_freq)
# Whole data length:
N = windowLength * (int(np.size(data, 1) / windowLength))
# Just for a partition of data:
# N = windowLength * 10
rest = np.size(data, 1) - N
preEEG_clean = []
for windowPosition in range(0, N, windowLength):
window = np.arange(windowPosition, (windowPosition + windowLength), 1, dtype=int)
preEEG_clean.append(FORCe(raw_croped.info, data[:, window]))
print('--- ' + str(window[0]) + '-' + str(window[-1]) + ' ---')
EEG_clean = np.concatenate(preEEG_clean, axis=1)
""" -- Plot all of chanels -- """
# for i in range(np.size(EEG_clean, 0)):
# plt.subplot(np.size(EEG_clean, 0), 1, i+1)
# plt.plot(EEG_clean[i, :])
# plt.show()
""" -- Plot chanels grouped 8 chanels -- """
# for i in range(int(np.size(EEG_clean, 0)/8)):
# for j in range(8):
# plt.subplot(np.size(EEG_clean, 0)/8, 1, j+1)
# plt.plot(np.multiply(data[(i*8)+j, :], 1e6), 'r')
# plt.plot(EEG_clean[(i * 8)+j, :])
# plt.show()
""" -- Plot chanels grouped 16 chanels -- """
for i in range(int(np.size(data, 0) / 16)):
for j in range(16):
plt.subplot(np.size(EEG_clean, 0) / 4, 1, j + 1)
plt.plot(np.multiply(data[(i * 16) + j, :], 1e6), 'r')
plt.plot(EEG_clean[(i * 16) + j, :])
plt.ylabel(raw.info['ch_names'][(i * 16)+j])
plt.show()