Source code for borsar.utils

import os
import os.path as op

import numpy as np
from contextlib import contextmanager


def find_range(vec, ranges):
    '''
    Find specified ranges in an ordered vector and return them as slices.

    Parameters
    ----------
    vec : numpy array
        Vector of sorted values.
    ranges: list of tuples/lists | two-element list/tuple
        Ranges or range to be found.

    Returns
    -------
    slices : slice or list of slices
        Slices representing the ranges. If one range was passed the output is
        a slice. If two or more ranges were passed the output is a list of
        slices. Each slice starts at the element of `vec` closest to the range
        start and ends at (and includes) the element closest to the range
        stop.
    '''
    assert isinstance(ranges, (list, tuple))
    assert len(ranges) > 0

    def closest(value):
        # index of the `vec` element nearest to `value`
        return np.abs(vec - value).argmin()

    # a flat two-element sequence like (0.1, 0.2) is treated as one range
    single_range = (not isinstance(ranges[0], (list, tuple))
                    and len(ranges) == 2)
    to_find = [ranges] if single_range else ranges

    found = list()
    for rng in to_find:
        start, stop = list(map(closest, rng))
        # stop + 1 so that the last index is included in the slice
        found.append(slice(start, stop + 1))

    return found[0] if single_range else found
def find_index(vec, vals):
    '''
    Find indices of values in `vec` that are closest to requested values
    `vals`.

    Parameters
    ----------
    vec : numpy array | list
        Vector of values.
    vals: list of values | numpy array | value
        Values to find closest representatives of in the `vec` vector.

    Returns
    -------
    idx : numpy array of int | int
        Indices of `vec` values closest to `vals`. If one value was passed in
        `vals` then `idx` will also be one value. If a list or tuple of values
        was passed in `vals` the output is a 1d numpy array of indices. If
        `vals` is a numpy array the output is an array of indices with the
        same shape as `vals`.
    '''
    vec = np.asarray(vec)

    def closest(value):
        # index of the `vec` element nearest to `value`
        return np.abs(vec - value).argmin()

    if isinstance(vals, np.ndarray):
        # work element-wise on the flattened array and restore the shape,
        # so that 2d (or 0d) arrays of values are handled correctly
        flat_idx = [closest(x) for x in vals.ravel()]
        return np.array(flat_idx, dtype=np.intp).reshape(vals.shape)

    if not isinstance(vals, (list, tuple)):
        # single value in -> single index out
        return closest(vals)

    # explicit integer dtype so that empty input still gives index array
    return np.array([closest(x) for x in vals], dtype=np.intp)
def get_info(inst):
    '''
    Return the mne ``Info`` object for `inst`.

    If `inst` already is an ``Info`` it is returned as is, otherwise its
    ``.info`` attribute is returned.
    '''
    from mne import Info
    return inst if isinstance(inst, Info) else inst.info
def _annot_to_array(annot, sfreq=None):
    '''
    Convert annotations to an N x 2 array of (start, stop) values.

    Parameters
    ----------
    annot : mne.Annotations | 2d numpy array
        Annotations object (read through its ``onset`` and ``duration``
        attributes) or N x 2 array of (start, stop) values.
    sfreq : float | None
        Sampling frequency. If not None the output is in samples: onsets and
        durations of annotation objects are converted to samples, and arrays
        that do not have an integer dtype are multiplied by `sfreq` and
        rounded. Arrays with an integer dtype are assumed to be in samples
        already and are returned unchanged.

    Returns
    -------
    annot_arr : 2d numpy array
        N x 2 array of (start, stop) values.
    '''
    samples = sfreq is not None
    if isinstance(annot, np.ndarray):
        # any integer dtype (not only int32 / int64) means "already samples"
        if not samples or np.issubdtype(annot.dtype, np.integer):
            return annot
        # FIXME - float array together with sfreq: maybe warn the user
        return np.round(annot * sfreq).astype('int')

    # not a numpy array - we assume it's mne.Annotations
    onset, duration = annot.onset, annot.duration
    if samples:
        onset = np.round(onset * sfreq).astype('int')
        duration = np.round(duration * sfreq).astype('int')
    return np.column_stack([onset, onset + duration])


def detect_overlap(segment, annot, sfreq=None):
    '''
    Detect what fraction of given segment is overlapping with annotations.

    Parameters
    ----------
    segment : list or 1d array
        Two-element list or array of [start, stop] values.
    annot : mne.Annotations | 2d numpy array | None
        Annotations or 2d array of N x 2 (start, stop) values. None means no
        annotations (zero overlap).
    sfreq : float
        Sampling frequency (default: None). If not None segment is assumed to
        be given in samples. `annot` is transformed to samples using sfreq if
        it is of mne.Annotations type. If `annot` is np.ndarray then it is
        transformed to samples only if its dtype is not an integer dtype.

    Returns
    -------
    overlap : float
        Overlap as a fraction of the segment length, in 0 - 1 range
        (annotations are assumed not to overlap each other).
    '''
    if annot is None:
        return 0.
    annot_arr = _annot_to_array(annot, sfreq=sfreq)

    seg_start, seg_stop = segment[0], segment[1]
    annot_start, annot_stop = annot_arr[:, 0], annot_arr[:, 1]

    # checks for boundary relationships
    ll_beleq = annot_start <= seg_start
    hh_abveq = annot_stop >= seg_stop

    # if any annotation starts at/before the segment start and ends at/after
    # the segment end, it covers the whole segment
    if (ll_beleq & hh_abveq).any():
        return 1.

    segment_length = np.diff(segment)[0]
    overlap = 0.
    ll_abv = ~ll_beleq
    hh_bel = ~hh_abveq

    # annotation starts at/before the segment start and ends inside the
    # segment: overlap from segment start up to annotation end (there should
    # be only one such annotation, as non-overlapping annotations are assumed)
    check = ll_beleq & (annot_stop > seg_start)
    if check.any():
        overlap += (annot_stop[check][0] - seg_start) / segment_length

    # annotations lying fully inside the segment (there can be many)
    check = ll_abv & hh_bel
    if check.any():
        overlap += (np.diff(annot_arr[check]) / segment_length).sum()

    # annotation starts inside the segment and ends at/after its end
    check = ll_abv & hh_abveq & (annot_start < seg_stop)
    if check.any():
        overlap += (seg_stop - annot_start[check][0]) / segment_length

    return overlap
# FIXME - add warnings etc.
# FIXME - there should be some mne function for that,
#         if so - use that function (check later)
def _check_tmin_tmax(raw, tmin, tmax):
    '''
    Clip `tmin` and `tmax` to the time range covered by `raw`.

    None (or a value outside the data) is replaced by the corresponding edge
    of the data: ``first_samp / sfreq`` for `tmin` and
    ``(last_samp + 1) / sfreq`` for `tmax`.

    Returns
    -------
    tmin : float
    tmax : float
    sfreq : float
        Sampling frequency read from ``raw.info['sfreq']``.
    '''
    sfreq = raw.info['sfreq']
    data_start = raw.first_samp / sfreq
    data_end = (raw.last_samp + 1) / sfreq

    if tmin is None or tmin < data_start:
        tmin = data_start
    if tmax is None or tmax > data_end:
        tmax = data_end
    return tmin, tmax, sfreq
def valid_windows(raw, tmin=None, tmax=None, winlen=2., step=1.):
    '''
    Test which moving windows are free of annotations.

    Parameters
    ----------
    raw : mne.Raw
        Data to use.
    tmin : float | None
        Start time for the moving windows. Defaults to None which means start
        of the raw data.
    tmax : float | None
        End time for the moving windows. Defaults to None which means end of
        the raw data.
    winlen : float
        Window length in seconds. Defaults to 2.
    step : float
        Window step in seconds. Defaults to 1.

    Returns
    -------
    valid : boolean numpy array
        Whether the moving windows are valid, that is do NOT overlap with any
        annotation (True - no overlap, False - some overlap). Consecutive
        values refer to consecutive windows. The array is empty when not even
        one window fits between `tmin` and `tmax`.
    '''
    annot = raw.annotations
    tmin, tmax, sfreq = _check_tmin_tmax(raw, tmin, tmax)
    step = int(round(step * sfreq))
    winlen = int(round(winlen * sfreq))
    tmin_smp, tmax_smp = int(round(tmin * sfreq)), int(round(tmax * sfreq))

    # number of whole windows fitting in [tmin_smp, tmax_smp]; clipped at zero
    # so that a range shorter than one window gives an empty result instead
    # of a negative array size (which np.zeros rejects)
    n_windows = max(0, (tmax_smp - tmin_smp - winlen) // step + 1)

    valid = np.zeros(n_windows, dtype='bool')
    for win_idx in range(n_windows):
        start = tmin_smp + win_idx * step
        segment = [start, start + winlen]
        overlap = detect_overlap(segment, annot, sfreq=sfreq)
        valid[win_idx] = overlap == 0.
    return valid
def _fake_channel_names(n_channels):
    '''Return `n_channels` unique channel names for fake data.'''
    from string import ascii_letters
    if n_channels <= len(ascii_letters):
        # single-letter names: 'a', 'b', ..., 'z', 'A', ..., 'Z'
        return list(ascii_letters[:n_channels])
    # not enough letters - fall back to numbered names
    return ['ch{}'.format(idx + 1) for idx in range(n_channels)]


def create_fake_raw(n_channels=4, n_samples=100, sfreq=125.):
    '''
    Create fake raw signal for testing.

    Parameters
    ----------
    n_channels : int, optional
        Number of channels in the fake raw signal. Defaults to 4. Up to 52
        channels are named with single ascii letters, more channels are named
        'ch1', 'ch2', etc.
    n_samples : int, optional
        Number of samples in the fake raw signal. Defaults to 100.
    sfreq : float, optional
        Sampling frequency of the fake raw signal. Defaults to 125.

    Returns
    -------
    raw : mne.io.RawArray
        Created raw array (all-zero EEG data).
    '''
    import mne

    ch_names = _fake_channel_names(n_channels)
    data = np.zeros((n_channels, n_samples))
    info = mne.create_info(ch_names, sfreq, ch_types='eeg')
    return mne.io.RawArray(data, info)
def get_dropped_epochs(epochs):
    '''
    Get indices of dropped epochs from `epochs.drop_log`.

    Parameters
    ----------
    epochs : mne Epochs instance
        Epochs to get dropped indices from.

    Returns
    -------
    dropped_epochs : 1d numpy array of int
        Array containing indices of dropped epochs. Entries of the drop log
        that contain 'IGNORED' are skipped and do not advance the epoch
        counter. Any other non-empty entry marks a dropped epoch.
    '''
    considered = (entry for entry in epochs.drop_log
                  if 'IGNORED' not in entry)
    dropped = [idx for idx, entry in enumerate(considered) if len(entry) > 0]
    # explicit int dtype: an empty result must still be usable as indices
    return np.array(dropped, dtype=int)
@contextmanager
def silent_mne():
    '''
    Context manager that silences warnings from mne-python.

    The mne log level is set to 'error' inside the ``with`` block and the
    previous log level is restored on exit - also when the body of the
    ``with`` block raises an exception.
    '''
    import mne
    log_level = mne.set_log_level('error', return_old_level=True)
    try:
        yield
    finally:
        # without try/finally an exception would leave mne silenced
        mne.set_log_level(log_level)
def _get_test_data_dir():
    '''Return the path to the ``data`` directory of the borsar package.'''
    from borsar import __path__ as package_paths
    package_root = package_paths[0]
    return op.join(package_root, 'data')
def download_test_data():
    '''
    Download additional test data from dropbox.

    Nothing is done if all checked test files already exist in the test data
    directory. Otherwise a zip archive is downloaded into that directory,
    extracted there, and then removed.
    '''
    import zipfile
    # NOTE(review): `_fetch_file` is a private mne helper and may be missing
    # in newer mne versions - confirm against the supported mne version
    from mne.utils import _fetch_file

    # check if test data exist
    data_dir = _get_test_data_dir()
    check_files = ['alpha_range_clusters.hdf5', 'DiamSar-eeg-oct-6-fwd.fif',
                   op.join('fsaverage', 'bem', 'fsaverage-ico-5-src.fif'),
                   'chan_alpha_range.hdf5']
    if all(op.isfile(op.join(data_dir, fname)) for fname in check_files):
        return

    # set up paths
    download_link = ('https://www.dropbox.com/sh/l4scs37524lb3pa/'
                     'AABCak4jORjgridWwHlwjhMHa?dl=1')
    destination = op.join(data_dir, 'temp_file.zip')

    # download the file
    _fetch_file(download_link, destination, print_destination=True,
                resume=True, timeout=30.)

    # unzip and extract; the archive is closed by the `with` statement and
    # the zip file is removed even if extraction fails, so that a broken
    # archive does not stay in the data directory
    # TODO - optionally extract only the missing files
    try:
        with zipfile.ZipFile(destination, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        os.remove(destination)