"""An easy-to-use platform for EEG experimentation in the classroom."""
1from collections import OrderedDict
2
3import numpy as np
4from matplotlib import pyplot as plt
5import pandas as pd # maybe we can remove this dependency
6
7from mne import (concatenate_raws, create_info, viz)
8from mne.io import RawArray
9from io import StringIO
10
11# import seaborn as sns
12# plt.style.use(fivethirtyeight)
13# sns.set_context('talk')
14# sns.set_style('white')
15
16
def load_data(sfreq=128., replace_ch_names=None):
    """Load CSV recordings from ``js.csvArray`` into one concatenated Raw object.

    Each element of ``js.csvArray`` is the text of one CSV file. Its first
    column is used as the index (timestamps). Every remaining column becomes
    a channel, and the last column is typed as a stimulus channel.

    Parameters
    ----------
    sfreq : float
        Fallback EEG sampling frequency. It is used only when the rate cannot
        be estimated from a file's timestamps (fewer than two rows, or a
        non-positive/NaN median time step). Otherwise the rate is estimated
        from the timestamps and snapped to 128 or 256 Hz.
    replace_ch_names : dict | None
        A dict containing a mapping to rename channels; names not in the
        mapping are kept. Useful when an external electrode was used during
        recording.

    Returns
    -------
    raw : an instance of mne.io.RawArray
        The loaded recordings, concatenated in the order of ``js.csvArray``.

    Raises
    ------
    ValueError
        If ``js.csvArray`` contains no recordings.
    """
    # ``js`` is not defined in this module. It is loaded in loadPackages by
    # the host environment and must expose ``csvArray`` (CSV texts).
    # TODO: Received attached variable name
    raws = []
    for csv in js.csvArray:
        data = pd.read_csv(StringIO(csv), index_col=0).dropna()

        # Estimate the sampling rate from the median timestamp step. The step
        # is assumed to be in milliseconds, hence 1000 / step. Snap the result
        # to 128 or 256 Hz. The median tolerates an occasional jittered gap,
        # unlike the first gap alone. A separate local keeps the caller's
        # ``sfreq`` intact as the fallback for every file.
        rec_sfreq = sfreq
        stamps = data.index.values
        if len(stamps) >= 2:
            step = np.median(np.diff(stamps))
            if step > 0:
                rec_sfreq = 256 if 1000 / step >= 200 else 128

        ch_names = list(data.columns)
        if replace_ch_names is not None:
            ch_names = [replace_ch_names.get(c, c) for c in ch_names]

        # All columns are kept: every one but the last is EEG, and the last
        # is treated as the stimulus channel.
        ch_types = ['eeg'] * (len(ch_names) - 1) + ['stim']

        info = create_info(ch_names=ch_names, ch_types=ch_types,
                           sfreq=rec_sfreq)
        # RawArray takes data shaped (n_channels, n_times), hence the transpose.
        raws.append(RawArray(data=data.values.T, info=info)
                    .set_montage('standard_1005'))

    if not raws:
        raise ValueError('No recordings found: js.csvArray is empty.')

    # concatenate all raw objects
    return concatenate_raws(raws)
77
78
def plot_topo(epochs, conditions=None):
    """Plot the averaged response of each condition with mne's topo plot.

    Parameters
    ----------
    epochs : an instance of mne.Epochs
        EEG epochs.
    conditions : dict | OrderedDict | list | None
        Conditions to plot. Each key (or list element) selects epochs via
        ``epochs[name]``; dict values are not used here. Iteration order sets
        the color and legend order.

    Returns
    -------
    evoked_topo : an instance of matplotlib.figure.Figure
        The figure returned by ``mne.viz.plot_evoked_topo``.

    Raises
    ------
    ValueError
        If no conditions are given.
    """
    if conditions is None:
        conditions = OrderedDict()
    if len(conditions) == 0:
        raise ValueError('plot_topo needs at least one condition to plot.')

    # Colors taken from seaborn's "hls" palette (hard-coded to avoid the
    # seaborn dependency). They are cycled so that any number of conditions
    # gets one color per evoked response.
    base_palette = [(0.86, 0.3712, 0.34),
                    (0.34, 0.86, 0.3712),
                    (0.3712, 0.34, 0.86)]
    evokeds = [epochs[name].average() for name in conditions]
    colors = [base_palette[i % len(base_palette)]
              for i in range(len(evokeds))]

    evoked_topo = viz.plot_evoked_topo(
        evokeds, vline=None, color=colors, show=False)
    # Transparent figure background and thicker traces.
    evoked_topo.patch.set_alpha(0)
    evoked_topo.set_size_inches(10, 8)
    for axis in evoked_topo.axes:
        for line in axis.lines:
            line.set_linewidth(2)

    # Legend is drawn on pyplot's current axes; each label's text is colored
    # to match its trace.
    labels = [e.comment if e.comment else 'Unknown' for e in evokeds]
    legend = plt.legend(labels, loc=0, prop={'size': 20})
    for txt, col in zip(legend.get_texts(), colors):
        txt.set_color(col)

    return evoked_topo
103
104
def plot_conditions(epochs, ch_ind=0, conditions=None, ci=97.5,
                    n_boot=1000, title='', palette=None, diff_waveform=(4, 3)):
    """Plot averaged epochs for each ERP condition with bootstrap intervals.

    Parameters
    ----------
    epochs : an instance of mne.Epochs
        EEG epochs.
    ch_ind : int
        An index of channel to plot data from.
    conditions : dict | OrderedDict | None
        Mapping from condition name (shown in the legend) to the event
        code(s) selecting its epochs. A value may be a list of codes or a
        single code. E.g.,

            conditions = {'Non-target': [0, 1],
                          'Target': [2, 3, 4]}

        Insertion order sets plotting, color and legend order. Conditions
        with no matching epochs are skipped.
    ci : float
        Width in percent (0-100) of the percentile bootstrap interval drawn
        around each condition mean.
    n_boot : int
        Number of bootstrap resamples per condition.
    title : str
        Title of the figure; omitted when empty.
    palette : list | None
        Colors for the conditions, reused cyclically when there are more
        conditions than colors. None (or empty) uses a built-in palette.
    diff_waveform : tuple | None
        ``(a, b)`` event codes. If given, ``mean(code b) - mean(code a)``
        is drawn in black and listed first in the legend as "b - a".
        If None, do not plot a difference waveform.

    Returns
    -------
    fig : an instance of matplotlib.figure.Figure
        A figure object.
    ax : an instance of matplotlib.axes.Axes
        The single axes the data is drawn on.
    """
    conditions = (OrderedDict() if conditions is None
                  else OrderedDict(conditions))

    if not palette:
        palette = [
            (0.86, 0.37, 0.34),
            (0.34, 0.86, 0.37),
            (0.37, 0.34, 0.86),
            (0.86, 0.72, 0.34),
        ]

    X = epochs.get_data()
    times = epochs.times
    y = pd.Series(epochs.events[:, -1])
    fig, ax = plt.subplots()

    # Lower/upper percentile of the bootstrap distribution, e.g. ci=97.5
    # gives the 1.25th and 98.75th percentiles.
    tail = (100 - ci) / 2

    cond_handles, cond_labels = [], []
    for idx, (name, codes) in enumerate(conditions.items()):
        if np.isscalar(codes):
            codes = [codes]
        cond_data = X[y.isin(codes).to_numpy(), ch_ind]
        n_samples = cond_data.shape[0]
        if n_samples == 0:
            # Nothing to average or resample; leave it out of plot and legend.
            continue
        color = palette[idx % len(palette)]
        mean = np.nanmean(cond_data, axis=0)
        boot_means = np.array([
            np.nanmean(
                cond_data[np.random.randint(0, n_samples, n_samples)], axis=0
            )
            for _ in range(n_boot)
        ])
        low = np.percentile(boot_means, tail, axis=0)
        high = np.percentile(boot_means, 100 - tail, axis=0)
        line, = ax.plot(times, mean, color=color)
        ax.fill_between(times, low, high, color=color, alpha=0.3)
        cond_handles.append(line)
        cond_labels.append(name)

    # Explicit handles are required: with labels alone matplotlib pairs them
    # with artists in creation order (mean lines, CI fills, diff line), which
    # mislabels the curves.
    legend_handles, legend_labels = [], []
    if diff_waveform:
        diff = (
            np.nanmean(X[(y == diff_waveform[1]).to_numpy(), ch_ind], axis=0)
            - np.nanmean(X[(y == diff_waveform[0]).to_numpy(), ch_ind], axis=0)
        )
        diff_line, = ax.plot(times, diff, color='k', lw=1)
        legend_handles.append(diff_line)
        legend_labels.append(
            '{} - {}'.format(diff_waveform[1], diff_waveform[0]))
    legend_handles += cond_handles
    legend_labels += cond_labels

    ax.set_title(epochs.ch_names[ch_ind])
    ax.axvline(x=0, color='k', lw=1, label='_nolegend_')

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (uV)')

    if legend_handles:
        ax.legend(legend_handles, legend_labels)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Size the figure and add the title before computing the layout, so the
    # layout is computed for the final figure.
    fig.set_size_inches(10, 8)
    if title:
        fig.suptitle(title, fontsize=20)
    fig.tight_layout()

    return fig, ax
207
def get_epochs_info(epochs):
    """Summarize per-event epoch counts and the share of dropped epochs.

    Prints a header line. Returns a list of single-key dicts:

    - one ``{event_name: len(epochs[event_name])}`` per key of
      ``epochs.event_id``;
    - ``{"Drop Percentage": pct}``, where
      ``pct = round((1 - len(epochs.events) / len(epochs.drop_log)) * 100, 2)``,
      or 0.0 when ``drop_log`` is empty;
    - ``{"Total Epochs": len(epochs.events)}``.

    Parameters
    ----------
    epochs : an instance of mne.Epochs
        The epochs to summarize.

    Returns
    -------
    info : list of dict
        The summary entries described above, in that order.
    """
    print('Get Epochs Info:')
    counts = [{name: len(epochs[name])} for name in epochs.event_id]
    n_kept = len(epochs.events)
    n_logged = len(epochs.drop_log)
    # An empty drop_log would otherwise raise ZeroDivisionError.
    drop_pct = round((1 - n_kept / n_logged) * 100, 2) if n_logged else 0.0
    return counts + [{"Drop Percentage": drop_pct}, {"Total Epochs": n_kept}]
213 {"Total Epochs": len(epochs.events)}]