import os
import os.path as op
import numpy as np
from contextlib import contextmanager
def find_range(vec, ranges):
    '''
    Find specified ranges in an ordered vector and return them as slices.

    Parameters
    ----------
    vec : numpy array
        Vector of sorted values.
    ranges : list of tuples/lists | two-element list/tuple
        Ranges or range to be found. Each range is a ``(low, high)`` pair;
        the elements of `vec` closest to ``low`` and ``high`` delimit the
        resulting slice.

    Returns
    -------
    slices : slice or list of slices
        Slices representing the ranges (the element closest to ``high`` is
        included). If one range was passed the output is a slice. If two or
        more ranges were passed the output is a list of slices.
    '''
    assert isinstance(ranges, (list, tuple))
    assert len(ranges) > 0

    # a flat two-element sequence such as [0.1, 0.5] describes one range
    single_range = (len(ranges) == 2
                    and not isinstance(ranges[0], (list, tuple)))
    range_list = [ranges] if single_range else ranges

    def _closest(value):
        # index of the element of `vec` nearest to `value`
        return np.abs(vec - value).argmin()

    found = list()
    for low, high in range_list:
        # + 1 so that the element closest to `high` is part of the slice
        found.append(slice(_closest(low), _closest(high) + 1))
    return found[0] if single_range else found
def find_index(vec, vals):
    '''
    Find indices of values in `vec` that are closest to requested values
    `vals`.

    Parameters
    ----------
    vec : numpy array
        Vector of values (anything accepted by ``np.asarray``).
    vals : list of values | numpy array | value
        Values to find closest representatives of in the `vec` vector.

    Returns
    -------
    idx : numpy array of int | int
        Indices of `vec` values closest to `vals`. If one value was passed in
        `vals` then `idx` will also be one value. If a list or tuple of
        values was passed the output is a 1d numpy array of indices. If a
        numpy array was passed the output is an integer array with the same
        shape as `vals` (a 0d array gives a single index).
    '''
    vec = np.asarray(vec)

    def _closest(value):
        # index of the element of `vec` nearest to `value`
        return np.abs(vec - value).argmin()

    if isinstance(vals, np.ndarray):
        # look up every element, then restore the shape of the input array
        flat = [_closest(value) for value in vals.ravel()]
        idx = np.array(flat, dtype=np.intp).reshape(vals.shape)
        # a 0d array behaves like a scalar input
        return idx[()] if vals.ndim == 0 else idx
    if isinstance(vals, (list, tuple)):
        return np.array([_closest(value) for value in vals])
    return _closest(vals)
def get_info(inst):
    '''
    Return the ``mne.Info`` associated with `inst`.

    If `inst` already is an ``mne.Info`` it is returned as is, otherwise its
    ``.info`` attribute is returned.
    '''
    from mne import Info
    return inst if isinstance(inst, Info) else inst.info
def _annot_to_array(annot, sfreq):
    '''
    Return `annot` as an N x 2 array of (start, stop) values.

    Parameters
    ----------
    annot : mne.Annotations | 2d numpy array
        Annotations object or N x 2 array of (start, stop) values.
    sfreq : float | None
        Sampling frequency. If not None the output is expressed in samples:
        mne.Annotations are always converted, numpy arrays are converted only
        when they do not have an integer dtype (integer arrays are taken to
        be in samples already).

    Returns
    -------
    annot_arr : 2d numpy array
        N x 2 array of (start, stop) values.
    '''
    if isinstance(annot, np.ndarray):
        # dtype *kind* check: works for any integer width / byte order
        if sfreq is not None and not np.issubdtype(annot.dtype, np.integer):
            # FIXME - consider warning when float times get converted
            return np.round(annot * sfreq).astype('int')
        return annot

    # anything that is not a numpy array is assumed to be mne.Annotations
    onset = np.asarray(annot.onset)
    duration = np.asarray(annot.duration)
    if sfreq is not None:
        onset = np.round(onset * sfreq).astype('int')
        duration = np.round(duration * sfreq).astype('int')
    return np.column_stack([onset, onset + duration])


def detect_overlap(segment, annot, sfreq=None):
    '''
    Detect what percentage of given segment is overlapping with annotations.

    Parameters
    ----------
    segment : list or 1d array
        Two-element list or array of [start, stop] values.
    annot : mne.Annotations | 2d numpy array | None
        Annotations or 2d array of N x 2 (start, stop) values. ``None`` means
        no annotations (overlap is 0).
    sfreq : float
        Sampling frequency (default: None). If not None segment is assumed to
        be given in samples. `annot` is transformed to samples using sfreq if
        it is of mne.Annotations type. If `annot` is np.ndarray then it is
        transformed to samples only if its dtype is not an integer dtype.

    Returns
    -------
    overlap : float
        Percentage overlap in 0 - 1 range (assuming the annotations do not
        overlap with each other).
    '''
    if annot is None:
        return 0.
    annot_arr = _annot_to_array(annot, sfreq)
    seg_start, seg_stop = segment[0], segment[1]
    starts, stops = annot_arr[:, 0], annot_arr[:, 1]

    starts_before = starts <= seg_start  # annot starts at / before segment
    stops_after = stops >= seg_stop      # annot ends at / after segment

    # an annotation that spans the whole segment -> 100% coverage
    if (starts_before & stops_after).any():
        return 1.

    segment_length = seg_stop - seg_start
    overlap = 0.

    # annotation starting before the segment and ending inside it
    check = starts_before & (stops > seg_start)
    if check.any():
        # only one such annotation (annotations assumed non-overlapping)
        overlap += (stops[check][0] - seg_start) / segment_length

    starts_after = ~starts_before
    # annotations lying fully inside the segment (there can be several)
    check = starts_after & ~stops_after
    if check.any():
        overlap += (stops[check] - starts[check]).sum() / segment_length

    # annotation starting inside the segment and ending after it
    check = stops_after & starts_after & (starts < seg_stop)
    if check.any():
        overlap += (seg_stop - starts[check][0]) / segment_length
    return overlap
# FIXME - add warnings etc.
# FIXME - there should be some mne function for that,
# if so - use that function (check later)
def _check_tmin_tmax(raw, tmin, tmax):
sfreq = raw.info['sfreq']
lowest_tmin = raw.first_samp / sfreq
highest_tmax = (raw.last_samp + 1) / sfreq
tmin = lowest_tmin if tmin is None or tmin < lowest_tmin else tmin
tmax = highest_tmax if tmax is None or tmax > highest_tmax else tmax
return tmin, tmax, sfreq
def valid_windows(raw, tmin=None, tmax=None, winlen=2., step=1.):
    '''
    Test which moving windows are free of annotations.

    Parameters
    ----------
    raw : mne.Raw
        Data to use.
    tmin : float | None
        Start time for the moving windows. Defaults to None which means start
        of the raw data.
    tmax : float | None
        End time for the moving windows. Defaults to None which means end of
        the raw data.
    winlen : float
        Window length in seconds. Defaults to 2.
    step : float
        Window step in seconds. Defaults to 1.

    Returns
    -------
    valid : boolean numpy array
        One value per consecutive moving window: ``True`` when the window
        does not overlap with any annotation, ``False`` when it does. The
        array is empty when the window is longer than the tmin - tmax range.
    '''
    annot = raw.annotations
    tmin, tmax, sfreq = _check_tmin_tmax(raw, tmin, tmax)
    step = int(round(step * sfreq))
    winlen = int(round(winlen * sfreq))
    tmin_smp, tmax_smp = int(round(tmin * sfreq)), int(round(tmax * sfreq))

    # number of whole windows fitting between tmin and tmax (exact integer
    # arithmetic); clamped at zero so a window longer than the data gives an
    # empty result instead of a negative array size
    n_windows = max(0, (tmax_smp - tmin_smp - winlen) // step + 1)

    valid = np.zeros(n_windows, dtype='bool')
    for win_idx in range(n_windows):
        start = tmin_smp + win_idx * step
        segment = [start, start + winlen]
        overlap = detect_overlap(segment, annot, sfreq=sfreq)
        valid[win_idx] = overlap == 0.
    return valid
def create_fake_raw(n_channels=4, n_samples=100, sfreq=125.):
    '''
    Create fake raw signal for testing.

    Parameters
    ----------
    n_channels : int, optional
        Number of channels in the fake raw signal. Defaults to 4.
    n_samples : int, optional
        Number of samples in the fake raw signal. Defaults to 100.
    sfreq : float, optional
        Sampling frequency of the fake raw signal. Defaults to 125.

    Returns
    -------
    raw : mne.io.RawArray
        Created raw array of zeros with EEG channels. Channels are named with
        single ascii letters ('a', 'b', ...) when there are enough letters,
        otherwise 'ch000', 'ch001', ...
    '''
    import mne
    from string import ascii_letters

    if n_channels <= len(ascii_letters):
        ch_names = list(ascii_letters[:n_channels])
    else:
        # not enough letters for unique one-character channel names
        ch_names = ['ch{:03d}'.format(idx) for idx in range(n_channels)]
    data = np.zeros((n_channels, n_samples))
    info = mne.create_info(ch_names, sfreq, ch_types='eeg')
    return mne.io.RawArray(data, info)
def get_dropped_epochs(epochs):
    '''
    Get indices of dropped epochs from `epochs.drop_log`.

    Drop-log entries containing ``'IGNORED'`` are skipped and do not count
    towards the epoch indices. Among the remaining entries, an epoch is
    considered dropped when its entry is not empty.

    Parameters
    ----------
    epochs : mne Epochs instance
        Epochs to get dropped indices from.

    Returns
    -------
    dropped_epochs : 1d numpy array of int
        Array containing indices of dropped epochs (an empty integer array
        when no epochs were dropped).
    '''
    considered = [log for log in epochs.drop_log if 'IGNORED' not in log]
    dropped_epochs = [idx for idx, log in enumerate(considered)
                      if len(log) > 0]
    # explicit int dtype so that an empty result is still a valid index array
    return np.array(dropped_epochs, dtype=int)
@contextmanager
def silent_mne():
    '''
    Context manager that silences warnings from mne-python.

    Sets the mne log level to ``'error'`` for the duration of the ``with``
    block and restores the previous level afterwards, also when the block
    raises an exception.
    '''
    import mne
    old_level = mne.set_log_level('error', return_old_level=True)
    try:
        yield
    finally:
        # without finally an exception in the body would leave mne silenced
        mne.set_log_level(old_level)
def _get_test_data_dir():
    '''Return the path of the ``data`` directory inside the borsar package.'''
    import borsar
    package_dir = borsar.__path__[0]
    return op.join(package_dir, 'data')
def download_test_data():
    '''
    Download additional test data from dropbox.

    Does nothing when all expected test files are already present in the
    test data directory. Otherwise the data are downloaded as a zip archive,
    extracted into the test data directory and the archive is removed.
    '''
    import zipfile
    # NOTE(review): `_fetch_file` is a private mne helper and may not exist
    # in newer mne versions - verify against the installed mne
    from mne.utils import _fetch_file

    # check if test data exist
    data_dir = _get_test_data_dir()
    check_files = ['alpha_range_clusters.hdf5', 'DiamSar-eeg-oct-6-fwd.fif',
                   op.join('fsaverage', 'bem', 'fsaverage-ico-5-src.fif'),
                   'chan_alpha_range.hdf5']
    if all(op.isfile(op.join(data_dir, fname)) for fname in check_files):
        return

    # set up paths
    download_link = ('https://www.dropbox.com/sh/l4scs37524lb3pa/'
                     'AABCak4jORjgridWwHlwjhMHa?dl=1')
    destination = op.join(data_dir, 'temp_file.zip')

    # download the file
    _fetch_file(download_link, destination, print_destination=True,
                resume=True, timeout=30.)

    # unzip and extract; the context manager closes the archive even if
    # extraction fails
    # TODO - optionally extract only the missing files
    with zipfile.ZipFile(destination, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

    # remove the zipfile
    os.remove(destination)