An easy-to-use platform for EEG experimentation in the classroom
at main 213 lines 6.7 kB view raw
from collections import OrderedDict

import numpy as np
from matplotlib import pyplot as plt
import pandas as pd  # maybe we can remove this dependency

from mne import (concatenate_raws, create_info, viz)
from mne.io import RawArray
from io import StringIO

# import seaborn as sns
# plt.style.use(fivethirtyeight)
# sns.set_context('talk')
# sns.set_style('white')


def load_data(sfreq=128., replace_ch_names=None):
    """Load the CSV strings in ``js.csvArray`` into one MNE Raw object.

    Each CSV string is parsed with pandas (first column used as the index),
    rows containing missing values are dropped, and the table is wrapped in
    an ``mne.io.RawArray`` with the ``standard_1005`` montage. The per-file
    objects are then concatenated with ``mne.concatenate_raws``.

    Parameters
    ----------
    sfreq : float
        Kept for backward compatibility only; the value is NOT used. The
        sampling frequency of every file is estimated from its index as
        ``1000 / (index[1] - index[0])`` and snapped to 256 Hz when the
        estimate is >= 200, otherwise to 128 Hz.
    replace_ch_names : dict | None
        A dict mapping original channel names to new names. Channels not
        present in the dict keep their name. Useful when an external
        electrode was used during recording.

    Returns
    -------
    raw : an instance of mne.io.RawArray
        The concatenated data of all CSV files.

    Raises
    ------
    ValueError
        If ``js.csvArray`` is empty, or if a file has fewer than two rows
        left after dropping missing values (no rate can be estimated).
    """
    # js is loaded in loadPackages
    # TODO: Receive the attached variable name instead of a hard-coded one.
    raws = []
    for csv in js.csvArray:
        data = pd.read_csv(StringIO(csv), index_col=0).dropna()

        if len(data.index) < 2:
            raise ValueError('Each CSV file needs at least two complete rows '
                             'to estimate the sampling rate.')

        # Estimate the sampling rate from the spacing of the first two index
        # values. NOTE(review): this assumes the index holds timestamps in
        # milliseconds -- confirm against the recording code.
        srate = 1000 / (data.index.values[1] - data.index.values[0])
        file_sfreq = 256 if srate >= 200 else 128

        ch_names = list(data.columns)
        if replace_ch_names is not None:
            ch_names = [replace_ch_names.get(c, c) for c in ch_names]

        # Every column is treated as EEG except the last one, which is the
        # stimulus (marker) channel. All columns are kept.
        ch_types = ['eeg'] * (len(ch_names) - 1) + ['stim']

        info = create_info(ch_names=ch_names, ch_types=ch_types,
                           sfreq=file_sfreq)
        # RawArray expects (n_channels, n_times), hence the transpose.
        raw = RawArray(data=data.values.T, info=info)
        raws.append(raw.set_montage('standard_1005'))

    if not raws:
        raise ValueError('No CSV data found in js.csvArray.')

    # concatenate all raw objects
    return concatenate_raws(raws)
def _cycle_palette(palette, n_colors):
    """Return ``n_colors`` colors from ``palette``, repeating it cyclically."""
    if not palette:
        return []
    return [palette[i % len(palette)] for i in range(n_colors)]


def plot_topo(epochs, conditions=None):
    """Plot the evoked response of each condition over the sensor layout.

    Parameters
    ----------
    epochs : an instance of mne.Epochs
        EEG epochs.
    conditions : dict | OrderedDict | None
        Only the keys are used: each key selects epochs via
        ``epochs[name]`` and is averaged into one evoked response.
        ``None`` is treated as an empty mapping.

    Returns
    -------
    evoked_topo : an instance of matplotlib.figure.Figure
        The figure returned by ``mne.viz.plot_evoked_topo``.
    """
    if conditions is None:
        conditions = OrderedDict()

    # palette = sns.color_palette("hls", len(conditions) + 1)
    # temp hack, just pull in the color palette from seaborn
    palette = [(0.85999999999999999, 0.37119999999999997, 0.33999999999999997),
               (0.33999999999999997, 0.85999999999999999, 0.37119999999999997),
               (0.37119999999999997, 0.33999999999999997, 0.85999999999999999)]
    # Reuse colors when there are more conditions than palette entries, so
    # every evoked response gets a color.
    colors = _cycle_palette(palette, len(conditions))
    evokeds = [epochs[name].average() for name in conditions]

    evoked_topo = viz.plot_evoked_topo(
        evokeds, vline=None, color=colors, show=False)
    evoked_topo.patch.set_alpha(0)  # transparent figure background
    evoked_topo.set_size_inches(10, 8)
    for axis in evoked_topo.axes:
        for line in axis.lines:
            line.set_linewidth(2)

    labels = [e.comment if e.comment else 'Unknown' for e in evokeds]
    legend = plt.legend(labels, loc=0, prop={'size': 20})
    # Color each legend label like its trace.
    for txt, col in zip(legend.get_texts(), colors):
        txt.set_color(col)

    return evoked_topo


def _bootstrap_ci(cond_data, ci, n_boot):
    """Return (mean, low, high) over axis 0 of ``cond_data`` (NaN-aware).

    ``low``/``high`` are the ``(100 - ci) / 2`` and ``100 - (100 - ci) / 2``
    percentiles of the means of ``n_boot`` resamples (with replacement)
    along axis 0.

    Raises
    ------
    ValueError
        If ``cond_data`` has no rows (no epochs to resample).
    """
    n_samples = cond_data.shape[0]
    if n_samples == 0:
        raise ValueError('Cannot bootstrap a condition with no epochs.')
    mean = np.nanmean(cond_data, axis=0)
    boot_means = np.array([
        np.nanmean(cond_data[np.random.randint(0, n_samples, n_samples)],
                   axis=0)
        for _ in range(n_boot)
    ])
    alpha = (100 - ci) / 2
    low = np.percentile(boot_means, alpha, axis=0)
    high = np.percentile(boot_means, 100 - alpha, axis=0)
    return mean, low, high


def plot_conditions(epochs, ch_ind=0, conditions=None, ci=97.5,
                    n_boot=1000, title='', palette=None, diff_waveform=(4, 3)):
    """Plot averaged epochs of each ERP condition with bootstrap CIs.

    For every condition the mean over its epochs is drawn as a line, with a
    shaded bootstrap confidence band around it. Optionally the difference
    between the averages of two marker values is overlaid in black.

    Parameters
    ----------
    epochs : an instance of mne.Epochs
        EEG epochs.
    ch_ind : int
        Index of the channel to plot data from.
    conditions : dict | OrderedDict | None
        Maps condition names (used as legend labels) to the marker
        numbers (last column of ``epochs.events``) belonging to that
        condition. A single marker number is also accepted.

        E.g.,

            conditions = {'Non-target': [0, 1],
                          'Target': [2, 3, 4]}

    ci : float
        Confidence interval of the bootstrap band, within [0, 100].
    n_boot : int
        Number of bootstrap resamples.
    title : str
        Figure title; nothing is drawn when empty.
    palette : list | None
        Colors for the conditions, in order; reused cyclically when there
        are more conditions than colors. Defaults to a four-color palette.
    diff_waveform : tuple | None
        Pair of marker numbers ``(a, b)``. When given, the average of the
        epochs with marker ``b`` minus the average of the epochs with
        marker ``a`` is plotted. If None (or empty), it is not plotted.

    Returns
    -------
    fig : an instance of matplotlib.figure.Figure
        The figure.
    ax : an instance of matplotlib.axes.Axes
        The single axes holding the plot.

    Raises
    ------
    ValueError
        If a condition matches no epochs.
    """
    if conditions is None:
        conditions = OrderedDict()
    elif isinstance(conditions, dict):
        conditions = OrderedDict(conditions)

    if palette is None:
        palette = [
            (0.86, 0.37, 0.34),
            (0.34, 0.86, 0.37),
            (0.37, 0.34, 0.86),
            (0.86, 0.72, 0.34),
        ]
    colors = _cycle_palette(palette, len(conditions))

    X = epochs.get_data()
    times = epochs.times
    markers = epochs.events[:, -1]
    fig, ax = plt.subplots()

    # Keep the mean-line handles so the legend can be built explicitly;
    # relying on automatic handle collection would also pick up the
    # fill_between bands and mismatch labels with lines.
    cond_handles = []
    for cond_markers, color in zip(conditions.values(), colors):
        cond_data = X[np.isin(markers, cond_markers), ch_ind]
        mean, low, high = _bootstrap_ci(cond_data, ci, n_boot)
        line, = ax.plot(times, mean, color=color)
        ax.fill_between(times, low, high, color=color, alpha=0.3)
        cond_handles.append(line)

    handles, labels = [], []
    if diff_waveform:
        diff = (np.nanmean(X[markers == diff_waveform[1], ch_ind], axis=0) -
                np.nanmean(X[markers == diff_waveform[0], ch_ind], axis=0))
        diff_line, = ax.plot(times, diff, color='k', lw=1)
        handles.append(diff_line)
        labels.append('{} - {}'.format(diff_waveform[1], diff_waveform[0]))
    handles += cond_handles
    labels += list(conditions.keys())

    ax.set_title(epochs.ch_names[ch_ind])
    ax.axvline(x=0, color='k', lw=1, label='_nolegend_')

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (uV)')

    ax.legend(handles, labels)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()

    if title:
        fig.suptitle(title, fontsize=20)

    fig.set_size_inches(10, 8)

    return fig, ax


def get_epochs_info(epochs):
    """Summarize epoch counts per event type and the share of dropped epochs.

    Prints a short header, then returns a list made of one
    ``{event_name: count}`` dict per key of ``epochs.event_id``, followed by
    ``{"Drop Percentage": pct}`` and ``{"Total Epochs": n}``, where
    ``n = len(epochs.events)`` and
    ``pct = round((1 - n / len(epochs.drop_log)) * 100, 2)``
    (0.0 when ``epochs.drop_log`` is empty).
    """
    print('Get Epochs Info:')
    n_kept = len(epochs.events)
    n_total = len(epochs.drop_log)
    # An empty drop log means there were no events at all; report 0% instead
    # of dividing by zero.
    drop_pct = round((1 - n_kept / n_total) * 100, 2) if n_total else 0.0
    return [*[{x: len(epochs[x])} for x in epochs.event_id],
            {"Drop Percentage": drop_pct},
            {"Total Epochs": n_kept}]