"""
YASA (Yet Another Spindle Algorithm): fast and robust detection of spindles,
slow-waves, and rapid eye movements from sleep EEG recordings.
- Author: Raphael Vallat (www.raphaelvallat.com)
- GitHub: https://github.com/raphaelvallat/yasa
- License: BSD 3-Clause License
"""
import mne
import logging
import numpy as np
import pandas as pd
from scipy import signal
from mne.filter import filter_data
from collections import OrderedDict
from scipy.interpolate import interp1d
from scipy.fftpack import next_fast_len
from sklearn.ensemble import IsolationForest
from .spectral import stft_power
from .numba import _detrend, _rms
from .io import set_log_level, is_tensorpac_installed, is_pyriemann_installed
from .others import (
moving_transform,
trimbothstd,
get_centered_indices,
sliding_window,
_merge_close,
_zerocrossings,
)
logger = logging.getLogger("yasa")
__all__ = [
"art_detect",
"spindles_detect",
"SpindlesResults",
"sw_detect",
"SWResults",
"rem_detect",
"REMResults",
"compare_detection",
]
#############################################################################
# DATA PREPROCESSING
#############################################################################
def _check_data_hypno(data, sf=None, ch_names=None, hypno=None, include=None, check_amp=True):
    """Validate and standardize the data and hypnogram prior to detection.

    Returns a tuple ``(data, sf, ch_names, hypno, include, mask, n_chan,
    n_samples, bad_chan)`` where ``data`` is a float64 array of shape
    (n_chan, n_samples), ``mask`` is a boolean per-sample vector of the sleep
    stages to keep, and ``bad_chan`` flags channels whose amplitude does not
    look like micro-Volts.
    """
    # 1) Coerce the input into a 2D float64 NumPy array (n_chan, n_samples)
    if isinstance(data, mne.io.BaseRaw):
        # MNE Raw object: pull sampling rate / channel names, convert V --> uV
        sf = data.info["sfreq"]
        ch_names = data.ch_names
        data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
    else:
        assert sf is not None, "sf must be specified if not using MNE Raw."
        if isinstance(sf, np.ndarray):
            # Handle a 0d array, e.g. sf = array(100.) --> 100.0
            sf = float(sf)
        assert isinstance(sf, (int, float)), "sf must be int or float."
    data = np.asarray(data, dtype=np.float64)
    assert data.ndim in [1, 2], "data must be 1D (times) or 2D (chan, times)."
    if data.ndim == 1:
        data = data[None, ...]  # promote a single channel to (1, n_samples)
    n_chan, n_samples = data.shape
    # 2) Default or validate the channel names
    if ch_names is None:
        ch_names = ["CHAN" + str(i).zfill(3) for i in range(n_chan)]
    else:
        assert len(ch_names) == n_chan
    # 3) Validate the hypnogram and the `include` stages
    if hypno is not None:
        hypno = np.asarray(hypno, dtype=int)
        assert hypno.ndim == 1, "Hypno must be one dimensional."
        assert hypno.size == n_samples, "Hypno must have same size as data."
        unique_hypno = np.unique(hypno)
        logger.info("Number of unique values in hypno = %i", unique_hypno.size)
        assert include is not None, "include cannot be None if hypno is given"
        include = np.atleast_1d(np.asarray(include))
        assert include.size >= 1, "`include` must have at least one element."
        assert hypno.dtype.kind == include.dtype.kind, "hypno and include must have same dtype"
        assert np.in1d(hypno, include).any(), (
            "None of the stages specified " "in `include` are present in " "hypno."
        )
    # 4) Flag channels whose amplitude is not plausibly expressed in uV
    logger.info("Number of samples in data = %i", n_samples)
    logger.info("Sampling frequency = %.2f Hz", sf)
    logger.info("Data duration = %.2f seconds", n_samples / sf)
    all_ptp = np.ptp(data, axis=-1)
    all_trimstd = trimbothstd(data, cut=0.05)
    bad_chan = np.zeros(n_chan, dtype=bool)
    for idx, name in enumerate(ch_names):
        logger.info("Trimmed standard deviation of %s = %.4f uV" % (name, all_trimstd[idx]))
        logger.info("Peak-to-peak amplitude of %s = %.4f uV" % (name, all_ptp[idx]))
        if check_amp and not (0.1 < all_trimstd[idx] < 1e3):
            logger.error(
                "Wrong data amplitude for %s "
                "(trimmed STD = %.3f). Unit of data MUST be uV! "
                "Channel will be skipped." % (name, all_trimstd[idx])
            )
            bad_chan[idx] = True
    # 5) Boolean mask of the samples belonging to the stages in `include`
    mask = np.in1d(hypno, include) if hypno is not None else np.ones(n_samples, dtype=bool)
    return (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan)
#############################################################################
# BASE DETECTION RESULTS CLASS
#############################################################################
class _DetectionResults(object):
    """Base container for YASA detection outputs.

    Holds the detected events dataframe together with the data, sampling
    frequency, channel names, hypnogram and filtered data used during the
    detection, and implements the summary / masking / plotting machinery
    shared by the SpindlesResults, SWResults and REMResults classes.
    """

    def __init__(self, events, data, sf, ch_names, hypno, data_filt):
        # events: pandas.DataFrame with one row per detected event.
        # data / data_filt: arrays used by get_mask / plotting — presumably
        # shaped (n_chan, n_samples); confirm against the detection functions.
        # hypno: per-sample sleep stage vector, or None.
        self._events = events
        self._data = data
        self._sf = sf
        self._hypno = hypno
        self._ch_names = ch_names
        self._data_filt = data_filt

    def _check_mask(self, mask):
        """Validate a per-event boolean mask; ``None`` selects all events."""
        assert isinstance(mask, (pd.Series, np.ndarray, list, type(None)))
        n_events = self._events.shape[0]
        if mask is None:
            mask = np.ones(n_events, dtype="bool")  # All events set to True
        else:
            mask = np.asarray(mask)
            assert mask.dtype.kind == "b", "Mask must be a boolean array."
            assert mask.ndim == 1, "Mask must be one-dimensional"
            assert mask.size == n_events, "Mask.size must be the number of detected events."
        return mask

    def summary(
        self, event_type, grp_chan=False, grp_stage=False, aggfunc="mean", sort=True, mask=None
    ):
        """Return the (optionally masked) events, aggregated per channel/stage.

        With no grouping requested, returns a copy of the masked events
        dataframe. Otherwise aggregates the event parameters with ``aggfunc``
        (``Start`` is counted, yielding a ``Count`` column) and, when a
        hypnogram is available and ``grp_stage`` is True, inserts a
        ``Density`` column (events per minute of each stage).
        """
        # Check masking
        mask = self._check_mask(mask)
        # Define grouping
        grouper = []
        if grp_stage is True and "Stage" in self._events:
            grouper.append("Stage")
        if grp_chan is True and "Channel" in self._events:
            grouper.append("Channel")
        if not len(grouper):
            # Return a copy of self._events after masking, without grouping
            return self._events.loc[mask, :].copy()
        # Aggregation dictionary depends on the type of detected event
        if event_type == "spindles":
            aggdict = {
                "Start": "count",
                "Duration": aggfunc,
                "Amplitude": aggfunc,
                "RMS": aggfunc,
                "AbsPower": aggfunc,
                "RelPower": aggfunc,
                "Frequency": aggfunc,
                "Oscillations": aggfunc,
                "Symmetry": aggfunc,
            }
            # if 'SOPhase' in self._events:
            #     from scipy.stats import circmean
            #     aggdict['SOPhase'] = lambda x: circmean(x, low=-np.pi, high=np.pi)
        elif event_type == "sw":
            aggdict = {
                "Start": "count",
                "Duration": aggfunc,
                "ValNegPeak": aggfunc,
                "ValPosPeak": aggfunc,
                "PTP": aggfunc,
                "Slope": aggfunc,
                "Frequency": aggfunc,
            }
            if "PhaseAtSigmaPeak" in self._events:
                # Phases are angles: average with a circular mean, not aggfunc
                from scipy.stats import circmean

                aggdict["PhaseAtSigmaPeak"] = lambda x: circmean(x, low=-np.pi, high=np.pi)
                aggdict["ndPAC"] = aggfunc
            if "CooccurringSpindle" in self._events:
                # We do not average "CooccurringSpindlePeak"
                aggdict["CooccurringSpindle"] = aggfunc
                aggdict["DistanceSpindleToSW"] = aggfunc
        else:  # REM
            aggdict = {
                "Start": "count",
                "Duration": aggfunc,
                "LOCAbsValPeak": aggfunc,
                "ROCAbsValPeak": aggfunc,
                "LOCAbsRiseSlope": aggfunc,
                "ROCAbsRiseSlope": aggfunc,
                "LOCAbsFallSlope": aggfunc,
                "ROCAbsFallSlope": aggfunc,
            }
        # Apply grouping, after masking
        df_grp = self._events.loc[mask, :].groupby(grouper, sort=sort, as_index=False).agg(aggdict)
        df_grp = df_grp.rename(columns={"Start": "Count"})
        # Calculate density (= number per min of each stage)
        if self._hypno is not None and grp_stage is True:
            stages = np.unique(self._events["Stage"])
            dur = {}
            for st in stages:
                # Get duration in minutes of each stage present in dataframe
                dur[st] = self._hypno[self._hypno == st].size / (60 * self._sf)
            # Insert new density column in grouped dataframe after count
            df_grp.insert(
                loc=df_grp.columns.get_loc("Count") + 1,
                column="Density",
                value=df_grp.apply(lambda rw: rw["Count"] / dur[rw["Stage"]], axis=1),
            )
        return df_grp.set_index(grouper)

    def get_mask(self):
        """Return a sample-level 0/1 mask of the detected events.

        The mask has the same shape as the data (squeezed for a single
        channel), with 1 for samples inside a detected event.
        """
        from yasa.others import _index_to_events

        mask = np.zeros(self._data.shape, dtype=int)
        for i in self._events["IdxChannel"].unique():
            ev_chan = self._events[self._events["IdxChannel"] == i]
            # Start/End are in seconds; convert to sample indices via sf
            idx_ev = _index_to_events(ev_chan[["Start", "End"]].to_numpy() * self._sf)
            mask[i, idx_ev] = 1
        return np.squeeze(mask)

    def get_sync_events(
        self, center, time_before, time_after, filt=(None, None), mask=None, as_dataframe=True
    ):
        """Return event-locked data epochs centered on ``center``.

        Not for REM — spindles & SW only. ``center`` is the name of an events
        column holding a time in seconds (e.g. the event peak). Returns either
        a long-format dataframe (Time, Event, Amplitude, [Stage], Channel,
        IdxChannel) or, if ``as_dataframe=False``, a list (one item per
        channel) of (n_events, n_times) arrays.
        """
        from yasa.others import get_centered_indices

        assert time_before >= 0
        assert time_after >= 0
        bef = int(self._sf * time_before)
        aft = int(self._sf * time_after)
        # TODO: Step size is determined by sf: 0.01 sec at 100 Hz, 0.002 sec at
        # 500 Hz, 0.00390625 sec at 256 Hz. Should we add resample=100 (Hz) or step_size=0.01?
        time = np.arange(-bef, aft + 1, dtype="int") / self._sf
        # Optional bandpass filtering of the data before epoching
        if any(filt):
            data = mne.filter.filter_data(
                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method="fir", verbose=False
            )
        else:
            data = self._data
        # Apply mask
        mask = self._check_mask(mask)
        masked_events = self._events.loc[mask, :]
        output = []
        for i in masked_events["IdxChannel"].unique():
            # Copy is required to merge with the stage later on
            ev_chan = masked_events[masked_events["IdxChannel"] == i].copy()
            ev_chan["Event"] = np.arange(ev_chan.shape[0])
            peaks = (ev_chan[center] * self._sf).astype(int).to_numpy()
            # Get centered indices (discards events too close to data bounds)
            idx, idx_valid = get_centered_indices(data[i, :], peaks, bef, aft)
            # If no good epochs are returned raise a warning
            if len(idx_valid) == 0:
                logger.error(
                    "Time before and/or time after exceed data bounds, please "
                    "lower the temporal window around center. Skipping channel."
                )
                continue
            # Get data at indices and time vector
            amps = data[i, idx]
            if not as_dataframe:
                # Output is a list (n_channels) of numpy arrays (n_events, n_times)
                output.append(amps)
                continue
            # Convert to long-format dataframe
            df_chan = pd.DataFrame(amps.T)
            df_chan["Time"] = time
            # Convert to long-format
            df_chan = df_chan.melt(id_vars="Time", var_name="Event", value_name="Amplitude")
            # Append stage
            if "Stage" in masked_events:
                df_chan = df_chan.merge(ev_chan[["Event", "Stage"]].iloc[idx_valid])
            # Append channel name
            df_chan["Channel"] = ev_chan["Channel"].iloc[0]
            df_chan["IdxChannel"] = i
            # Append to master dataframe
            output.append(df_chan)
        if as_dataframe:
            output = pd.concat(output, ignore_index=True)
        return output

    def get_coincidence_matrix(self, scaled=True):
        """Return a (n_chan, n_chan) matrix of pairwise event co-occurrence.

        Each cell is the number of samples where events overlap on both
        channels; when ``scaled`` is True it is divided by the product of the
        per-channel event-sample counts (NaN when that product is zero).
        """
        if len(self._ch_names) < 2:
            raise ValueError("At least 2 channels are required to calculate coincidence.")
        mask = self.get_mask()
        mask = pd.DataFrame(mask.T, columns=self._ch_names)
        mask.columns.name = "Channel"

        def _coincidence(x, y):
            """Calculate the (scaled) coincidence."""
            coincidence = (x * y).sum()
            if scaled:
                # Handle division by zero error
                denom = x.sum() * y.sum()
                if denom == 0:
                    coincidence = np.nan
                else:
                    coincidence /= denom
            return coincidence

        # pandas.DataFrame.corr accepts an arbitrary callable as `method`
        coinc_mat = mask.corr(method=_coincidence)
        if not scaled:
            # Otherwise diagonal values are set to 1
            np.fill_diagonal(coinc_mat.values, mask.sum())
            coinc_mat = coinc_mat.astype(int)
        return coinc_mat

    def compare_channels(self, score="f1", max_distance_sec=0):
        """
        Compare detected events across channels.

        See full documentation in the methods of SpindlesResults and SWResults.
        """
        from itertools import product

        assert score in ["f1", "precision", "recall"], f"Invalid scoring metric: {score}"
        # Extract events and channel
        detected = self.summary()
        chan = detected["Channel"].unique()
        # Get indices of start in deciseconds, rounding to nearest deciseconds (100 ms).
        # This is needed for three reasons:
        # 1. Speed up the for loop
        # 2. Avoid memory error in yasa.compare_detection
        # 3. Make sure that max_distance works even when self and other have different sf.
        # TODO: Only the Start of the event is currently supported. Add more flexibility?
        detected["Start"] = (detected["Start"] * 10).round().astype(int)
        max_distance = int(10 * max_distance_sec)
        # Initialize output dataframe / dict
        scores = pd.DataFrame(index=chan, columns=chan, dtype=float)
        scores.index.name = "Channel"
        scores.columns.name = "Channel"
        pairs = list(product(chan, repeat=2))
        # Loop across pair of channels
        for c_index, c_col in pairs:
            idx_chan1 = detected[detected["Channel"] == c_index]["Start"]
            idx_chan2 = detected[detected["Channel"] == c_col]["Start"]
            # DANGER: Note how we invert idx_chan2 and idx_chan1 here. This is because
            # idx_chan1 (the index of the dataframe) should be the ground-truth.
            # NOTE: this bare call resolves to the module-level compare_detection
            # function, not to the method of the same name on this class.
            res = compare_detection(idx_chan2, idx_chan1, max_distance)
            scores.loc[c_index, c_col] = res[score]
        return scores

    def compare_detection(self, other, max_distance_sec=0, other_is_groundtruth=True):
        """
        Compare detected events between two detection methods, or against a ground-truth scoring.

        See full documentation in the methods of SpindlesResults and SWResults.
        """
        detected = self.summary()
        # `other` may be another YASA results object or a plain dataframe
        # with at least "Start" and "Channel" columns.
        if isinstance(other, (SpindlesResults, SWResults, REMResults)):
            groundtruth = other.summary()
        elif isinstance(other, pd.DataFrame):
            assert "Start" in other.columns
            assert "Channel" in other.columns
            groundtruth = other[["Start", "Channel"]].copy()
        else:
            raise ValueError(
                f"Invalid argument other: {other}. It must be a YASA detection output or a Pandas "
                f"DataFrame with the columns Start and Channels"
            )
        # Get indices of start in deciseconds, rounding to nearest deciseconds (100 ms).
        # This is needed for three reasons:
        # 1. Speed up the for loop
        # 2. Avoid memory error in yasa.compare_detection
        # 3. Make sure that max_distance works even when self and other have different sf.
        detected["Start"] = (detected["Start"] * 10).round().astype(int)
        groundtruth["Start"] = (groundtruth["Start"] * 10).round().astype(int)
        max_distance = int(10 * max_distance_sec)
        # Find channels that are present in both self and other
        chan_detected = detected["Channel"].unique()
        chan_groundtruth = groundtruth["Channel"].unique()
        chan_both = np.intersect1d(chan_detected, chan_groundtruth)  # Sort
        if not len(chan_both):
            raise ValueError(
                f"No intersecting channel between self and other:\n"
                f"{chan_detected}\n{chan_groundtruth}"
            )
        # The output is a pandas.DataFrame (n_chan, n_metrics).
        scores = pd.DataFrame(
            index=chan_both, columns=["precision", "recall", "f1", "n_self", "n_other"], dtype=float
        )
        scores.index.name = "Channel"
        # Loop on each channel
        for c_index in chan_both:
            idx_detected = detected[detected["Channel"] == c_index]["Start"]
            idx_groundtruth = groundtruth[groundtruth["Channel"] == c_index]["Start"]
            # Bare call resolves to the module-level compare_detection function
            if other_is_groundtruth:
                res = compare_detection(idx_detected, idx_groundtruth, max_distance)
            else:
                res = compare_detection(idx_groundtruth, idx_detected, max_distance)
            scores.loc[c_index, "precision"] = res["precision"]
            scores.loc[c_index, "recall"] = res["recall"]
            scores.loc[c_index, "f1"] = res["f1"]
            scores.loc[c_index, "n_self"] = len(idx_detected)
            scores.loc[c_index, "n_other"] = len(idx_groundtruth)
        scores["n_self"] = scores["n_self"].astype(int)
        scores["n_other"] = scores["n_other"].astype(int)
        return scores

    def plot_average(
        self,
        event_type,
        center="Peak",
        hue="Channel",
        time_before=1,
        time_after=1,
        filt=(None, None),
        mask=None,
        figsize=(6, 4.5),
        **kwargs,
    ):
        """Plot the average event, time-locked on ``center``.

        Not for REM — spindles & SW only. ``**kwargs`` are passed through to
        :py:func:`seaborn.lineplot`. Returns the matplotlib Axes.
        """
        import seaborn as sns
        import matplotlib.pyplot as plt

        df_sync = self.get_sync_events(
            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask
        )
        assert not df_sync.empty, "Could not calculate event-locked data."
        assert hue in ["Stage", "Channel"], "hue must be 'Channel' or 'Stage'"
        assert hue in df_sync.columns, "%s is not present in data." % hue
        if event_type == "spindles":
            title = "Average spindle"
        else:  # "sw":
            title = "Average SW"
        # Start figure
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        sns.lineplot(data=df_sync, x="Time", y="Amplitude", hue=hue, ax=ax, **kwargs)
        # ax.legend(frameon=False, loc='lower right')
        ax.set_xlim(df_sync["Time"].min(), df_sync["Time"].max())
        ax.set_title(title)
        ax.set_xlabel("Time (sec)")
        ax.set_ylabel("Amplitude (uV)")
        return ax

    def plot_detection(self):
        """Plot an overlay of the detected events on the signal.

        Builds an interactive matplotlib figure driven by ipywidgets sliders
        (epoch, amplitude, channel, window size, raw/filtered toggle), so it
        is intended for use in a Jupyter notebook.
        """
        import matplotlib.pyplot as plt
        import ipywidgets as ipy

        # Define mask
        sf = self._sf
        win_size = 10
        mask = self.get_mask()
        # Replace non-event samples by NaN so matplotlib leaves gaps
        highlight = self._data * mask
        highlight = np.where(highlight == 0, np.nan, highlight)
        highlight_filt = self._data_filt * mask
        highlight_filt = np.where(highlight_filt == 0, np.nan, highlight_filt)
        n_epochs = int((self._data.shape[-1] / sf) / win_size)
        times = np.arange(self._data.shape[-1]) / sf
        # Define xlim and xrange
        xlim = [0, win_size]
        xrng = np.arange(xlim[0] * sf, (xlim[1] * sf + 1), dtype=int)
        # Plot
        fig, ax = plt.subplots(figsize=(12, 4))
        plt.plot(times[xrng], self._data[0, xrng], "k", lw=1)
        plt.plot(times[xrng], highlight[0, xrng], "indianred")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Amplitude (uV)")
        fig.canvas.header_visible = False
        fig.tight_layout()
        # WIDGETS
        layout = ipy.Layout(width="50%", justify_content="center", align_items="center")
        sl_ep = ipy.IntSlider(
            min=0,
            max=n_epochs,
            step=1,
            value=0,
            layout=layout,
            description="Epoch:",
        )
        sl_amp = ipy.IntSlider(
            min=25,
            max=500,
            step=25,
            value=150,
            layout=layout,
            orientation="horizontal",
            description="Amplitude:",
        )
        dd_ch = ipy.Dropdown(
            options=self._ch_names, value=self._ch_names[0], description="Channel:"
        )
        dd_win = ipy.Dropdown(
            options=[1, 5, 10, 30, 60],
            value=win_size,
            description="Window size:",
        )
        dd_check = ipy.Checkbox(
            value=False,
            description="Filtered",
        )

        def update(epoch, amplitude, channel, win_size, filt):
            """Update plot."""
            # Recompute the epoch count so the slider adapts to the window size
            n_epochs = int((self._data.shape[-1] / sf) / win_size)
            sl_ep.max = n_epochs
            xlim = [epoch * win_size, (epoch + 1) * win_size]
            xrng = np.arange(xlim[0] * sf, (xlim[1] * sf), dtype=int)
            # Check if filtered
            data = self._data if not filt else self._data_filt
            overlay = highlight if not filt else highlight_filt
            try:
                ax.lines[0].set_data(times[xrng], data[dd_ch.index, xrng])
                ax.lines[1].set_data(times[xrng], overlay[dd_ch.index, xrng])
                ax.set_xlim(xlim)
            except IndexError:
                # Last (partial) epoch may index past the data; keep the old view
                pass
            ax.set_ylim([-amplitude, amplitude])

        return ipy.interact(
            update, epoch=sl_ep, amplitude=sl_amp, channel=dd_ch, win_size=dd_win, filt=dd_check
        )
#############################################################################
# SPINDLES DETECTION
#############################################################################
[docs]def spindles_detect(
data,
sf=None,
ch_names=None,
hypno=None,
include=(1, 2, 3),
freq_sp=(12, 15),
freq_broad=(1, 30),
duration=(0.5, 2),
min_distance=500,
thresh={"rel_pow": 0.2, "corr": 0.65, "rms": 1.5},
multi_only=False,
remove_outliers=False,
verbose=False,
):
"""Spindles detection.
Parameters
----------
data : array_like
Single or multi-channel data. Unit must be uV and shape (n_samples) or
(n_chan, n_samples). Can also be a :py:class:`mne.io.BaseRaw`,
in which case ``data``, ``sf``, and ``ch_names`` will be automatically
extracted, and ``data`` will also be automatically converted from
Volts (MNE) to micro-Volts (YASA).
sf : float
Sampling frequency of the data in Hz.
Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw`.
.. tip:: If the detection is taking too long, make sure to downsample
your data to 100 Hz (or 128 Hz). For more details, please refer to
:py:func:`mne.filter.resample`.
ch_names : list of str
Channel names. Can be omitted if ``data`` is a
:py:class:`mne.io.BaseRaw`.
hypno : array_like
Sleep stage (hypnogram). If the hypnogram is loaded, the
detection will only be applied to the value defined in
``include`` (default = N1 + N2 + N3 sleep).
The hypnogram must have the same number of samples as ``data``.
To upsample your hypnogram, please refer to
:py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
include : tuple, list or int
Values in ``hypno`` that will be included in the mask. The default is
(1, 2, 3), meaning that the detection is applied on N1, N2 and N3
sleep. This has no effect when ``hypno`` is None.
freq_sp : tuple or list
Spindles frequency range. Default is 12 to 15 Hz. Please note that YASA
uses a FIR filter (implemented in MNE) with a 1.5Hz transition band,
which means that for `freq_sp = (12, 15 Hz)`, the -6 dB points are
located at 11.25 and 15.75 Hz.
freq_broad : tuple or list
Broad band frequency range. Default is 1 to 30 Hz.
duration : tuple or list
The minimum and maximum duration of the spindles.
Default is 0.5 to 2 seconds.
min_distance : int
If two spindles are closer than ``min_distance`` (in ms), they are
merged into a single spindles. Default is 500 ms.
thresh : dict
Detection thresholds:
* ``'rel_pow'``: Relative power (= power ratio freq_sp / freq_broad).
* ``'corr'``: Moving correlation between original signal and
sigma-filtered signal.
* ``'rms'``: Number of standard deviations above the mean of a moving
root mean square of sigma-filtered signal.
You can disable one or more threshold by putting ``None`` instead:
.. code-block:: python
thresh = {'rel_pow': None, 'corr': 0.65, 'rms': 1.5}
thresh = {'rel_pow': None, 'corr': None, 'rms': 3}
multi_only : boolean
Define the behavior of the multi-channel detection. If True, only
spindles that are present on at least two channels are kept. If False,
no selection is applied and the output is just a concatenation of the
single-channel detection dataframe. Default is False.
remove_outliers : boolean
If True, YASA will automatically detect and remove outliers spindles
using :py:class:`sklearn.ensemble.IsolationForest`.
The outliers detection is performed on all the spindles
parameters with the exception of the ``Start``, ``Peak``, ``End``,
``Stage``, and ``SOPhase`` columns.
YASA uses a random seed (42) to ensure reproducible results.
Note that this step will only be applied if there are more than 50
detected spindles in the first place. Default to False.
verbose : bool or str
Verbose level. Default (False) will only print warning and error
messages. The logging levels are 'debug', 'info', 'warning', 'error',
and 'critical'. For most users the choice is between 'info'
(or ``verbose=True``) and warning (``verbose=False``).
.. versionadded:: 0.2.0
Returns
-------
sp : :py:class:`yasa.SpindlesResults`
To get the full detection dataframe, use:
>>> sp = spindles_detect(...)
>>> sp.summary()
This will give a :py:class:`pandas.DataFrame` where each row is a
detected spindle and each column is a parameter (= feature or property)
of this spindle. To get the average spindles parameters per channel and
sleep stage:
>>> sp.summary(grp_chan=True, grp_stage=True)
Notes
-----
The parameters that are calculated for each spindle are:
* ``'Start'``: Start time of the spindle, in seconds from the beginning of
data.
* ``'Peak'``: Time at the most prominent spindle peak (in seconds).
* ``'End'`` : End time (in seconds).
* ``'Duration'``: Duration (in seconds)
* ``'Amplitude'``: Peak-to-peak amplitude of the (detrended) spindle in
the raw data (in µV).
* ``'RMS'``: Root-mean-square (in µV)
* ``'AbsPower'``: Median absolute power (in log10 µV^2),
calculated from the Hilbert-transform of the ``freq_sp`` filtered signal.
* ``'RelPower'``: Median relative power of the ``freq_sp`` band in spindle
calculated from a short-term fourier transform and expressed as a
proportion of the total power in ``freq_broad``.
* ``'Frequency'``: Median instantaneous frequency of spindle (in Hz),
derived from an Hilbert transform of the ``freq_sp`` filtered signal.
* ``'Oscillations'``: Number of oscillations (= number of positive peaks
in spindle.)
* ``'Symmetry'``: Location of the most prominent peak of spindle,
normalized from 0 (start) to 1 (end). Ideally this value should be close
to 0.5, indicating that the most prominent peak is halfway through the
spindle.
* ``'Stage'`` : Sleep stage during which spindle occured, if ``hypno``
was provided.
All parameters are calculated from the broadband-filtered EEG
(frequency range defined in ``freq_broad``).
For better results, apply this detection only on artefact-free NREM sleep.
.. warning::
A critical bug was fixed in YASA 0.6.1, in which the number of detected spindles could
vary drastically depending on the sampling frequency of the data. Please make sure to check
any results obtained with this function prior to the 0.6.1 release.
References
----------
The sleep spindles detection algorithm is based on:
* Lacourse, K., Delfrate, J., Beaudry, J., Peppard, P., & Warby, S. C.
(2018). `A sleep spindle detection algorithm that emulates human expert
spindle scoring. <https://doi.org/10.1016/j.jneumeth.2018.08.014>`_
Journal of Neuroscience Methods.
Examples
--------
For a walkthrough of the spindles detection, please refer to the following
Jupyter notebooks:
https://github.com/raphaelvallat/yasa/blob/master/notebooks/01_spindles_detection.ipynb
https://github.com/raphaelvallat/yasa/blob/master/notebooks/02_spindles_detection_multi.ipynb
https://github.com/raphaelvallat/yasa/blob/master/notebooks/03_spindles_detection_NREM_only.ipynb
https://github.com/raphaelvallat/yasa/blob/master/notebooks/04_spindles_slow_fast.ipynb
"""
set_log_level(verbose)
(data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
data, sf, ch_names, hypno, include
)
# If all channels are bad
if sum(bad_chan) == n_chan:
logger.warning("All channels have bad amplitude. Returning None.")
return None
# Check detection thresholds
if "rel_pow" not in thresh.keys():
thresh["rel_pow"] = 0.20
if "corr" not in thresh.keys():
thresh["corr"] = 0.65
if "rms" not in thresh.keys():
thresh["rms"] = 1.5
do_rel_pow = thresh["rel_pow"] not in [None, "none", "None"]
do_corr = thresh["corr"] not in [None, "none", "None"]
do_rms = thresh["rms"] not in [None, "none", "None"]
n_thresh = sum([do_rel_pow, do_corr, do_rms])
assert n_thresh >= 1, "At least one threshold must be defined."
# Filtering
nfast = next_fast_len(n_samples)
# 1) Broadband bandpass filter (optional -- careful of lower freq for PAC)
data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1], method="fir", verbose=0)
# 2) Sigma bandpass filter
# The width of the transition band is set to 1.5 Hz on each side,
# meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located at
# 11.25 and 15.75 Hz.
data_sigma = filter_data(
data,
sf,
freq_sp[0],
freq_sp[1],
l_trans_bandwidth=1.5,
h_trans_bandwidth=1.5,
method="fir",
verbose=0,
)
# Hilbert power (to define the instantaneous frequency / power)
analytic = signal.hilbert(data_sigma, N=nfast)[:, :n_samples]
inst_phase = np.angle(analytic)
inst_pow = np.square(np.abs(analytic))
inst_freq = sf / (2 * np.pi) * np.diff(inst_phase, axis=-1)
# Extract the SO signal for coupling
# if coupling:
# # We need to use the original (non-filtered data)
# data_so = filter_data(data, sf, freq_so[0], freq_so[1], method='fir',
# l_trans_bandwidth=0.1, h_trans_bandwidth=0.1,
# verbose=0)
# # Now extract the instantaneous phase using Hilbert transform
# so_phase = np.angle(signal.hilbert(data_so, N=nfast)[:, :n_samples])
# Initialize empty output dataframe
df = pd.DataFrame()
for i in range(n_chan):
# ####################################################################
# START SINGLE CHANNEL DETECTION
# ####################################################################
# First, skip channels with bad data amplitude
if bad_chan[i]:
continue
# Compute the pointwise relative power using interpolated STFT
# Here we use a step of 200 ms to speed up the computation.
# Note that even if the threshold is None we still need to calculate it
# for the individual spindles parameter (RelPow).
f, t, Sxx = stft_power(
data_broad[i, :], sf, window=2, step=0.2, band=freq_broad, interp=False, norm=True
)
idx_sigma = np.logical_and(f >= freq_sp[0], f <= freq_sp[1])
rel_pow = Sxx[idx_sigma].sum(0)
# Let's interpolate `rel_pow` to get one value per sample
# Note that we could also have use the `interp=True` in the
# `stft_power` function, however 2D interpolation is much slower than
# 1D interpolation.
func = interp1d(t, rel_pow, kind="cubic", bounds_error=False, fill_value=0)
t = np.arange(n_samples) / sf
rel_pow = func(t)
if do_corr:
_, mcorr = moving_transform(
x=data_sigma[i, :],
y=data_broad[i, :],
sf=sf,
window=0.3,
step=0.1,
method="corr",
interp=True,
)
if do_rms:
_, mrms = moving_transform(
x=data_sigma[i, :], sf=sf, window=0.3, step=0.1, method="rms", interp=True
)
# Let's define the thresholds
if hypno is None:
thresh_rms = mrms.mean() + thresh["rms"] * trimbothstd(mrms, cut=0.10)
else:
thresh_rms = mrms[mask].mean() + thresh["rms"] * trimbothstd(mrms[mask], cut=0.10)
# Avoid too high threshold caused by Artefacts / Motion during Wake
thresh_rms = min(thresh_rms, 10)
logger.info("Moving RMS threshold = %.3f", thresh_rms)
# Boolean vector of supra-threshold indices
idx_sum = np.zeros(n_samples)
if do_rel_pow:
idx_rel_pow = (rel_pow >= thresh["rel_pow"]).astype(int)
idx_sum += idx_rel_pow
logger.info("N supra-theshold relative power = %i", idx_rel_pow.sum())
if do_corr:
idx_mcorr = (mcorr >= thresh["corr"]).astype(int)
idx_sum += idx_mcorr
logger.info("N supra-theshold moving corr = %i", idx_mcorr.sum())
if do_rms:
idx_mrms = (mrms >= thresh_rms).astype(int)
idx_sum += idx_mrms
logger.info("N supra-theshold moving RMS = %i", idx_mrms.sum())
# Make sure that we do not detect spindles outside mask
if hypno is not None:
idx_sum[~mask] = 0
# The detection using the three thresholds tends to underestimate the
# real duration of the spindle. To overcome this, we compute a soft
# threshold by smoothing the idx_sum vector with a ~100 ms window.
# Sampling frequency = 100 Hz --> w = 10 samples
# Sampling frequency = 256 Hz --> w = 25 samples = 97 ms
w = int(0.1 * sf)
# Critical bugfix March 2022, see https://github.com/raphaelvallat/yasa/pull/55
idx_sum = np.convolve(idx_sum, np.ones(w), mode="same") / w
# And we then find indices that are strictly greater than 2, i.e. we
# find the 'true' beginning and 'true' end of the events by finding
# where at least two out of the three thresholds were crossed.
where_sp = np.where(idx_sum > (n_thresh - 1))[0]
# If no events are found, skip to next channel
if not len(where_sp):
logger.warning("No spindle were found in channel %s.", ch_names[i])
continue
# Merge events that are too close
if min_distance is not None and min_distance > 0:
where_sp = _merge_close(where_sp, min_distance, sf)
# Extract start, end, and duration of each spindle
sp = np.split(where_sp, np.where(np.diff(where_sp) != 1)[0] + 1)
idx_start_end = np.array([[k[0], k[-1]] for k in sp]) / sf
sp_start, sp_end = idx_start_end.T
sp_dur = sp_end - sp_start
# Find events with bad duration
good_dur = np.logical_and(sp_dur > duration[0], sp_dur < duration[1])
# If no events of good duration are found, skip to next channel
if all(~good_dur):
logger.warning("No spindle were found in channel %s.", ch_names[i])
continue
# Initialize empty variables
sp_amp = np.zeros(len(sp))
sp_freq = np.zeros(len(sp))
sp_rms = np.zeros(len(sp))
sp_osc = np.zeros(len(sp))
sp_sym = np.zeros(len(sp))
sp_abs = np.zeros(len(sp))
sp_rel = np.zeros(len(sp))
sp_sta = np.zeros(len(sp))
sp_pro = np.zeros(len(sp))
# sp_cou = np.zeros(len(sp))
# Number of oscillations (number of peaks separated by at least 60 ms)
# --> 60 ms because 1000 ms / 16 Hz = 62.5 ms, in other words, at 16 Hz,
# peaks are separated by 62.5 ms. At 11 Hz peaks are separated by 90 ms
distance = 60 * sf / 1000
for j in np.arange(len(sp))[good_dur]:
# Important: detrend the signal to avoid wrong PTP amplitude
sp_x = np.arange(data_broad[i, sp[j]].size, dtype=np.float64)
sp_det = _detrend(sp_x, data_broad[i, sp[j]])
# sp_det = signal.detrend(data_broad[i, sp[i]], type='linear')
sp_amp[j] = np.ptp(sp_det) # Peak-to-peak amplitude
sp_rms[j] = _rms(sp_det) # Root mean square
sp_rel[j] = np.median(rel_pow[sp[j]]) # Median relative power
# Hilbert-based instantaneous properties
sp_inst_freq = inst_freq[i, sp[j]]
sp_inst_pow = inst_pow[i, sp[j]]
sp_abs[j] = np.median(np.log10(sp_inst_pow[sp_inst_pow > 0]))
sp_freq[j] = np.median(sp_inst_freq[sp_inst_freq > 0])
# Number of oscillations
peaks, peaks_params = signal.find_peaks(
sp_det, distance=distance, prominence=(None, None)
)
sp_osc[j] = len(peaks)
# For frequency and amplitude, we can also optionally use these
# faster alternatives. If we use them, we do not need to compute
# the Hilbert transform of the filtered signal.
# sp_freq[j] = sf / np.mean(np.diff(peaks))
# sp_amp[j] = peaks_params['prominences'].max()
# Peak location & symmetry index
# pk is expressed in sample since the beginning of the spindle
pk = peaks[peaks_params["prominences"].argmax()]
sp_pro[j] = sp_start[j] + pk / sf
sp_sym[j] = pk / sp_det.size
# SO-spindles coupling
# if coupling:
# sp_cou[j] = so_phase[i, sp[j]][pk]
# Sleep stage
if hypno is not None:
sp_sta[j] = hypno[sp[j]][0]
# Create a dataframe
sp_params = {
"Start": sp_start,
"Peak": sp_pro,
"End": sp_end,
"Duration": sp_dur,
"Amplitude": sp_amp,
"RMS": sp_rms,
"AbsPower": sp_abs,
"RelPower": sp_rel,
"Frequency": sp_freq,
"Oscillations": sp_osc,
"Symmetry": sp_sym,
# 'SOPhase': sp_cou,
"Stage": sp_sta,
}
df_chan = pd.DataFrame(sp_params)[good_dur]
# We need at least 50 detected spindles to apply the Isolation Forest.
if remove_outliers and df_chan.shape[0] >= 50:
col_keep = [
"Duration",
"Amplitude",
"RMS",
"AbsPower",
"RelPower",
"Frequency",
"Oscillations",
"Symmetry",
]
ilf = IsolationForest(
contamination="auto", max_samples="auto", verbose=0, random_state=42
)
good = ilf.fit_predict(df_chan[col_keep])
good[good == -1] = 0
logger.info(
"%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
)
# Remove outliers from DataFrame
df_chan = df_chan[good.astype(bool)]
logger.info("%i spindles were found in channel %s." % (df_chan.shape[0], ch_names[i]))
# ####################################################################
# END SINGLE CHANNEL DETECTION
# ####################################################################
df_chan["Channel"] = ch_names[i]
df_chan["IdxChannel"] = i
df = pd.concat([df, df_chan], axis=0, ignore_index=True)
# If no spindles were detected, return None
if df.empty:
logger.warning("No spindles were found in data. Returning None.")
return None
# Remove useless columns
to_drop = []
if hypno is None:
to_drop.append("Stage")
else:
df["Stage"] = df["Stage"].astype(int)
# if not coupling:
# to_drop.append('SOPhase')
if len(to_drop):
df = df.drop(columns=to_drop)
# Find spindles that are present on at least two channels
if multi_only and df["Channel"].nunique() > 1:
# We round to the nearest second
idx_good = np.logical_or(
df["Start"].round(0).duplicated(keep=False), df["End"].round(0).duplicated(keep=False)
).to_list()
df = df[idx_good].reset_index(drop=True)
return SpindlesResults(
events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_sigma
)
class SpindlesResults(_DetectionResults):
    """Output class for spindles detection.

    Attributes
    ----------
    _events : :py:class:`pandas.DataFrame`
        Output detection dataframe
    _data : array_like
        Original EEG data of shape *(n_chan, n_samples)*.
    _data_filt : array_like
        Sigma-filtered EEG data of shape *(n_chan, n_samples)*.
    _sf : float
        Sampling frequency of data.
    _ch_names : list
        Channel names.
    _hypno : array_like or None
        Sleep staging vector.
    """

    def __init__(self, events, data, sf, ch_names, hypno, data_filt):
        super().__init__(events, data, sf, ch_names, hypno, data_filt)

    def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True):
        """Return a summary of the spindles detection, optionally grouped
        across channels and/or stage.

        Parameters
        ----------
        grp_chan : bool
            If True, group by channel (for multi-channels detection only).
        grp_stage : bool
            If True, group by sleep stage (provided that an hypnogram was
            used).
        mask : array_like or None
            Custom boolean mask. Only the detected events for which mask is True will be
            included in the summary dataframe. Default is None, i.e. no masking
            (all events are included).
        aggfunc : str or function
            Averaging function (e.g. ``'mean'`` or ``'median'``).
        sort : bool
            If True, sort group keys when grouping.
        """
        return super().summary(
            event_type="spindles",
            grp_chan=grp_chan,
            grp_stage=grp_stage,
            aggfunc=aggfunc,
            sort=sort,
            mask=mask,
        )

    def get_coincidence_matrix(self, scaled=True):
        """Return the (scaled) coincidence matrix.

        Parameters
        ----------
        scaled : bool
            If True (default), the coincidence matrix is scaled (see Notes).

        Returns
        -------
        coincidence : pd.DataFrame
            A symmetric matrix with the (scaled) coincidence values.

        Notes
        -----
        Do spindles occur at the same time? One way to measure this is to
        calculate the coincidence matrix, which gives, for each pair of
        channel, the number of samples that were marked as a spindle in both
        channels. The output is a symmetric matrix, in which the diagonal is
        simply the number of data points that were marked as a spindle in the
        channel.

        The coincidence matrix can be scaled (default) by dividing the output
        by the product of the sum of each individual binary mask, as shown in
        the example below. It can then be used to define functional
        networks or quickly find outlier channels.

        Examples
        --------
        Calculate the coincidence of two binary mask:

        >>> import numpy as np
        >>> x = np.array([0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1])
        >>> y = np.array([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1])
        >>> x * y
        array([0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1])

        >>> (x * y).sum()  # Unscaled coincidence
        3

        >>> (x * y).sum() / (x.sum() * y.sum())  # Scaled coincidence
        0.12

        References
        ----------
        - https://github.com/Mark-Kramer/Sleep-Networks-2021
        """
        return super().get_coincidence_matrix(scaled=scaled)

    def compare_channels(self, score="f1", max_distance_sec=0):
        """
        Compare detected spindles across channels.

        This is a wrapper around the :py:func:`yasa.compare_detection` function. Please
        refer to the documentation of this function for more details.

        Parameters
        ----------
        score : str
            The performance metric to compute. Accepted values are "precision", "recall"
            (aka sensitivity) and "f1" (default). The F1-score is the harmonic mean of precision
            and recall, and is usually the preferred metric to evaluate the agreement between
            two channels. All three metrics are bounded by 0 and 1, where 1 indicates perfect
            agreement.
        max_distance_sec : float
            The maximum distance between spindles, in seconds, to consider as the same event.

            .. warning:: To reduce computation cost, YASA rounds the start time of each spindle
                to the nearest decisecond (= 100 ms). This means that the lowest possible
                resolution is 100 ms, regardless of the sampling frequency of the data. Two
                spindles starting at 500 ms and 540 ms on their respective channels will
                therefore always be considered the same event, even when max_distance_sec=0.

        Returns
        -------
        scores : :py:class:`pandas.DataFrame`
            A Pandas DataFrame with the output scores, of shape (n_chan, n_chan).

        Notes
        -----
        Some use cases of this function:

        1. What proportion of spindles detected in one channel are also detected on
           another channel (if using ``score="recall"``).
        2. What is the overall agreement in the detected events between channels?
        3. Is the agreement better in channels that are close to one another?
        """
        return super().compare_channels(score, max_distance_sec)

    def compare_detection(self, other, max_distance_sec=0, other_is_groundtruth=True):
        """
        Compare the detected spindles against either another YASA detection or against custom
        annotations (e.g. ground-truth human scoring).

        This function is a wrapper around the :py:func:`yasa.compare_detection` function. Please
        refer to the documentation of this function for more details.

        Parameters
        ----------
        other : dataframe or detection results
            This can be either a) the output of another YASA detection, for example if you want
            to test the impact of tweaking some parameters on the detected events or b) a pandas
            DataFrame with custom annotations, obtained by another detection method outside
            of YASA, or with manual labelling. If b), the dataframe must contain the "Start" and
            "Channel" columns, with the start of each event in seconds from the beginning
            of the recording and the channel name, respectively. The channel names should match
            the output of the summary() method.
        max_distance_sec : float
            The maximum distance between spindles, in seconds, to consider as the same event.

            .. warning:: To reduce computation cost, YASA rounds the start time of each spindle
                to the nearest decisecond (= 100 ms). This means that the lowest possible
                resolution is 100 ms, regardless of the sampling frequency of the data.
        other_is_groundtruth : bool
            If True (default), ``other`` will be considered as the ground-truth scoring. If
            False, the current detection will be considered as the ground-truth, and the
            precision and recall scores will be inverted. This parameter has no effect on the
            F1-score.

            .. note:: when ``other`` is the ground-truth (default), the recall score is the
                fraction of events in other that were successfully detected by the current
                detection, and the precision score is the proportion of detected events by the
                current detection that are also present in other.

        Returns
        -------
        scores : :py:class:`pandas.DataFrame`
            A Pandas DataFrame with the channel names as index, and the following columns

            * ``precision``: Precision score, aka positive predictive value
            * ``recall``: Recall score, aka sensitivity
            * ``f1``: F1-score
            * ``n_self``: Number of detected events in ``self`` (current method).
            * ``n_other``: Number of detected events in ``other``.

        Notes
        -----
        Some use cases of this function:

        1. How well does YASA events detection perform against ground-truth human annotations?
        2. If I change the threshold(s) of the events detection, do the detected events match
           those obtained with the default parameters?
        3. Which detection thresholds give the highest agreement with the ground-truth scoring?
        """
        return super().compare_detection(other, max_distance_sec, other_is_groundtruth)

    def get_mask(self):
        """
        Return a boolean array indicating for each sample in data if this
        sample is part of a detected event (True) or not (False).
        """
        return super().get_mask()

    def get_sync_events(
        self,
        center="Peak",
        time_before=1,
        time_after=1,
        filt=(None, None),
        mask=None,
        as_dataframe=True,
    ):
        """
        Return the raw or filtered data of each detected event after
        centering to a specific timepoint.

        Parameters
        ----------
        center : str
            Landmark of the event to synchronize the timing on.
            Default is to use the center peak of the spindles.
        time_before : float
            Time (in seconds) before ``center``.
        time_after : float
            Time (in seconds) after ``center``.
        filt : tuple
            Optional filtering to apply to data. For instance, ``filt=(1, 30)``
            will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
            will apply a 40 Hz lowpass filter. Filtering is done using default
            parameters in the :py:func:`mne.filter.filter_data` function.
        mask : array_like or None
            Custom boolean mask. Only the detected events for which mask is True will be
            included. Default is None, i.e. no masking (all events are included).
        as_dataframe : boolean
            If True (default), returns a long-format pandas dataframe. If False, returns a list
            of numpy arrays. Each element of the list a unique channel, and the shape of the
            numpy arrays within the list is (n_events, n_times).

        Returns
        -------
        df_sync : :py:class:`pandas.DataFrame`
            Output long-format dataframe (if ``as_dataframe=True``)::

            'Event' : Event number
            'Time' : Timing of the events (in seconds)
            'Amplitude' : Raw or filtered data for event
            'Channel' : Channel
            'IdxChannel' : Index of channel in data
            'Stage': Sleep stage in which the events occurred (if available)
        """
        return super().get_sync_events(
            center=center,
            time_before=time_before,
            time_after=time_after,
            filt=filt,
            mask=mask,
            as_dataframe=as_dataframe,
        )

    def plot_average(
        self,
        center="Peak",
        hue="Channel",
        time_before=1,
        time_after=1,
        filt=(None, None),
        mask=None,
        figsize=(6, 4.5),
        **kwargs,
    ):
        """
        Plot the average spindle.

        Parameters
        ----------
        center : str
            Landmark of the event to synchronize the timing on.
            Default is to use the most prominent peak of the spindle.
        hue : str
            Grouping variable that will produce lines with different colors.
            Can be either 'Channel' or 'Stage'.
        time_before : float
            Time (in seconds) before ``center``.
        time_after : float
            Time (in seconds) after ``center``.
        filt : tuple
            Optional filtering to apply to data. For instance, ``filt=(12, 16)``
            will apply a 12 to 16 Hz bandpass filter, and ``filt=(None, 40)``
            will apply a 40 Hz lowpass filter. Filtering is done using the default
            parameters in the :py:func:`mne.filter.filter_data` function.
        mask : array_like or None
            Custom boolean mask. Only the detected events for which mask is True will be
            plotted. Default is None, i.e. no masking (all events are included).
        figsize : tuple
            Figure size in inches.
        **kwargs : dict
            Optional argument that are passed to :py:func:`seaborn.lineplot`.
        """
        return super().plot_average(
            event_type="spindles",
            center=center,
            hue=hue,
            time_before=time_before,
            time_after=time_after,
            filt=filt,
            mask=mask,
            figsize=figsize,
            **kwargs,
        )

    def plot_detection(self):
        """Plot an overlay of the detected spindles on the EEG signal.

        This only works in Jupyter and it requires the ipywidgets
        (https://ipywidgets.readthedocs.io/en/latest/) package.

        To activate the interactive mode, make sure to run:

        >>> %matplotlib widget

        .. versionadded:: 0.4.0
        """
        return super().plot_detection()
#############################################################################
# SLOW-WAVES DETECTION
#############################################################################
[docs]def sw_detect(
data,
sf=None,
ch_names=None,
hypno=None,
include=(2, 3),
freq_sw=(0.3, 1.5),
dur_neg=(0.3, 1.5),
dur_pos=(0.1, 1),
amp_neg=(40, 200),
amp_pos=(10, 150),
amp_ptp=(75, 350),
coupling=False,
coupling_params={"freq_sp": (12, 16), "time": 1, "p": 0.05},
remove_outliers=False,
verbose=False,
):
"""Slow-waves detection.
Parameters
----------
data : array_like
Single or multi-channel data. Unit must be uV and shape (n_samples) or
(n_chan, n_samples). Can also be a :py:class:`mne.io.BaseRaw`,
in which case ``data``, ``sf``, and ``ch_names`` will be automatically
extracted, and ``data`` will also be automatically converted from
Volts (MNE) to micro-Volts (YASA).
sf : float
Sampling frequency of the data in Hz.
Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw`.
.. tip:: If the detection is taking too long, make sure to downsample
your data to 100 Hz (or 128 Hz). For more details, please refer to
:py:func:`mne.filter.resample`.
ch_names : list of str
Channel names. Can be omitted if ``data`` is a
:py:class:`mne.io.BaseRaw`.
hypno : array_like
Sleep stage (hypnogram). If the hypnogram is loaded, the
detection will only be applied to the value defined in
``include`` (default = N2 + N3 sleep).
The hypnogram must have the same number of samples as ``data``.
To upsample your hypnogram, please refer to
:py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
include : tuple, list or int
Values in ``hypno`` that will be included in the mask. The default is
(2, 3), meaning that the detection is applied on N2 and N3
sleep. This has no effect when ``hypno`` is None.
freq_sw : tuple or list
Slow wave frequency range. Default is 0.3 to 1.5 Hz. Please note that
YASA uses a FIR filter (implemented in MNE) with a 0.2 Hz transition
band, which means that the -6 dB points are located at 0.2 and 1.6 Hz.
dur_neg : tuple or list
The minimum and maximum duration of the negative deflection of the
slow wave. Default is 0.3 to 1.5 second.
dur_pos : tuple or list
The minimum and maximum duration of the positive deflection of the
slow wave. Default is 0.1 to 1 second.
amp_neg : tuple or list
Absolute minimum and maximum negative trough amplitude of the
slow-wave. Default is 40 uV to 200 uV. Can also be in unit of standard
deviations if the data has been previously z-scored. If you do not want
to specify any negative amplitude thresholds,
use ``amp_neg=(None, None)``.
amp_pos : tuple or list
Absolute minimum and maximum positive peak amplitude of the
slow-wave. Default is 10 uV to 150 uV. Can also be in unit of standard
deviations if the data has been previously z-scored.
If you do not want to specify any positive amplitude thresholds,
use ``amp_pos=(None, None)``.
amp_ptp : tuple or list
Minimum and maximum peak-to-peak amplitude of the slow-wave.
Default is 75 uV to 350 uV. Can also be in unit of standard
deviations if the data has been previously z-scored.
Use ``np.inf`` to set no upper amplitude threshold
(e.g. ``amp_ptp=(75, np.inf)``).
coupling : boolean
If True, YASA will also calculate the phase-amplitude coupling between
the slow-waves phase and the spindles-related sigma band
amplitude. Specifically, the following columns will be added to the
output dataframe:
1. ``'SigmaPeak'``: The location (in seconds) of the maximum sigma peak amplitude within a
2-seconds epoch centered around the negative peak (through) of the current slow-wave.
2. ``PhaseAtSigmaPeak``: the phase of the bandpas-filtered slow-wave signal (in radians)
at ``'SigmaPeak'``.
Importantly, since ``PhaseAtSigmaPeak`` is expressed in radians, one should use circular
statistics to calculate the mean direction and vector length:
.. code-block:: python
import pingouin as pg
mean_direction = pg.circ_mean(sw['PhaseAtSigmaPeak'])
vector_length = pg.circ_r(sw['PhaseAtSigmaPeak'])
3. ``ndPAC``: the normalized Mean Vector Length (also called the normalized direct PAC,
or ndPAC) within a 2-sec epoch centered around the negative peak of the slow-wave.
The lower and upper frequencies for the slow-waves and spindles-related sigma signals are
defined in ``freq_sw`` and ``coupling_params['freq_sp']``, respectively.
For more details, please refer to the `Jupyter notebook
<https://github.com/raphaelvallat/yasa/blob/master/notebooks/12_SO-sigma_coupling.ipynb>`_
Note that setting ``coupling=True`` may increase computation time.
.. versionadded:: 0.2.0
coupling_params : dict
Parameters for the phase-amplitude coupling.
* ``freq_sp`` is a tuple or list that defines the spindles-related frequency of interest.
The default is 12 to 16 Hz, with a wide transition bandwidth of 1.5 Hz.
* ``time`` is an int or a float that defines the time around the negative peak of each
detected slow-waves, in seconds. For example, a value of 1 means that the coupling will
be calculated for each slow-waves using a 2-seconds epoch centered around the negative
peak of the slow-waves (i.e. 1 second on each side).
* ``p`` is a parameter passed to the :py:func:`tensorpac.methods.norm_direct_pac``
function. It represents the p-value to use for thresholding of unreliable coupling
values. Sub-threshold PAC values will be set to 0. To disable this behavior (no masking),
use ``p=1`` or ``p=None``.
.. versionadded:: 0.6.0
remove_outliers : boolean
If True, YASA will automatically detect and remove outliers slow-waves
using :py:class:`sklearn.ensemble.IsolationForest`.
The outliers detection is performed on the frequency, amplitude and
duration parameters of the detected slow-waves. YASA uses a random seed
(42) to ensure reproducible results. Note that this step will only be
applied if there are more than 50 detected slow-waves in the first
place. Default to False.
verbose : bool or str
Verbose level. Default (False) will only print warning and error
messages. The logging levels are 'debug', 'info', 'warning', 'error',
and 'critical'. For most users the choice is between 'info'
(or ``verbose=True``) and warning (``verbose=False``).
.. versionadded:: 0.2.0
Returns
-------
sw : :py:class:`yasa.SWResults`
To get the full detection dataframe, use:
>>> sw = sw_detect(...)
>>> sw.summary()
This will give a :py:class:`pandas.DataFrame` where each row is a
detected slow-wave and each column is a parameter (= property).
To get the average SW parameters per channel and sleep stage:
>>> sw.summary(grp_chan=True, grp_stage=True)
Notes
-----
The parameters that are calculated for each slow-wave are:
* ``'Start'``: Start time of each detected slow-wave, in seconds from the beginning of data.
* ``'NegPeak'``: Location of the negative peak (in seconds)
* ``'MidCrossing'``: Location of the negative-to-positive zero-crossing (in seconds)
* ``'Pospeak'``: Location of the positive peak (in seconds)
* ``'End'``: End time(in seconds)
* ``'Duration'``: Duration (in seconds)
* ``'ValNegPeak'``: Amplitude of the negative peak (in uV, calculated on the ``freq_sw``
bandpass-filtered signal)
* ``'ValPosPeak'``: Amplitude of the positive peak (in uV, calculated on the ``freq_sw``
bandpass-filtered signal)
* ``'PTP'``: Peak-to-peak amplitude (= ``ValPosPeak`` - ``ValNegPeak``, calculated on the
``freq_sw`` bandpass-filtered signal)
* ``'Slope'``: Slope between ``NegPeak`` and ``MidCrossing`` (in uV/sec, calculated on the
``freq_sw`` bandpass-filtered signal)
* ``'Frequency'``: Frequency of the slow-wave (= 1 / ``Duration``)
* ``'SigmaPeak'``: Location of the sigma peak amplitude within a 2-sec epoch centered around
the negative peak of the slow-wave. This is only calculated when ``coupling=True``.
* ``'PhaseAtSigmaPeak'``: SW phase at max sigma amplitude within a 2-sec epoch centered around
the negative peak of the slow-wave. This is only calculated when ``coupling=True``
* ``'ndPAC'``: Normalized direct PAC within a 2-sec epoch centered around the negative peak
of the slow-wave. This is only calculated when ``coupling=True``
* ``'Stage'``: Sleep stage (only if hypno was provided)
.. image:: https://raw.githubusercontent.com/raphaelvallat/yasa/master/docs/pictures/slow_waves.png # noqa
:width: 500px
:align: center
:alt: slow-wave
For better results, apply this detection only on artefact-free NREM sleep.
References
----------
The slow-waves detection algorithm is based on:
* Massimini, M., Huber, R., Ferrarelli, F., Hill, S., & Tononi, G. (2004). `The sleep slow
oscillation as a traveling wave. <https://doi.org/10.1523/JNEUROSCI.1318-04.2004>`_. The
Journal of Neuroscience, 24(31), 6862–6870.
* Carrier, J., Viens, I., Poirier, G., Robillard, R., Lafortune, M., Vandewalle, G., Martin,
N., Barakat, M., Paquet, J., & Filipini, D. (2011). `Sleep slow wave changes during the
middle years of life. <https://doi.org/10.1111/j.1460-9568.2010.07543.x>`_. The European
Journal of Neuroscience, 33(4), 758–766.
Examples
--------
For an example of how to run the detection, please refer to the tutorial:
https://github.com/raphaelvallat/yasa/blob/master/notebooks/05_sw_detection.ipynb
"""
set_log_level(verbose)
(data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
data, sf, ch_names, hypno, include
)
# If all channels are bad
if sum(bad_chan) == n_chan:
logger.warning("All channels have bad amplitude. Returning None.")
return None
# Define time vector
times = np.arange(data.size) / sf
idx_mask = np.where(mask)[0]
# Bandpass filter
nfast = next_fast_len(n_samples)
data_filt = filter_data(
data,
sf,
freq_sw[0],
freq_sw[1],
method="fir",
verbose=0,
l_trans_bandwidth=0.2,
h_trans_bandwidth=0.2,
)
# Extract the spindles-related sigma signal for coupling
if coupling:
is_tensorpac_installed()
import tensorpac.methods as tpm
# The width of the transition band is set to 1.5 Hz on each side,
# meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located
# at 11.25 and 15.75 Hz. The frequency band for the amplitude signal
# must be large enough to fit the sidebands caused by the assumed
# modulating lower frequency band (Aru et al. 2015).
# https://doi.org/10.1016/j.conb.2014.08.002
assert isinstance(coupling_params, dict)
assert "freq_sp" in coupling_params.keys()
assert "time" in coupling_params.keys()
assert "p" in coupling_params.keys()
freq_sp = coupling_params["freq_sp"]
data_sp = filter_data(
data,
sf,
freq_sp[0],
freq_sp[1],
method="fir",
l_trans_bandwidth=1.5,
h_trans_bandwidth=1.5,
verbose=0,
)
# Now extract the instantaneous phase/amplitude using Hilbert transform
sw_pha = np.angle(signal.hilbert(data_filt, N=nfast)[:, :n_samples])
sp_amp = np.abs(signal.hilbert(data_sp, N=nfast)[:, :n_samples])
# Initialize empty output dataframe
df = pd.DataFrame()
for i in range(n_chan):
# ####################################################################
# START SINGLE CHANNEL DETECTION
# ####################################################################
# First, skip channels with bad data amplitude
if bad_chan[i]:
continue
# Find peaks in data
# Negative peaks with value comprised between -40 to -300 uV
idx_neg_peaks, _ = signal.find_peaks(-1 * data_filt[i, :], height=amp_neg)
# Positive peaks with values comprised between 10 to 200 uV
idx_pos_peaks, _ = signal.find_peaks(data_filt[i, :], height=amp_pos)
# Intersect with sleep stage vector
idx_neg_peaks = np.intersect1d(idx_neg_peaks, idx_mask, assume_unique=True)
idx_pos_peaks = np.intersect1d(idx_pos_peaks, idx_mask, assume_unique=True)
# If no peaks are detected, return None
if len(idx_neg_peaks) == 0 or len(idx_pos_peaks) == 0:
logger.warning("No SW were found in channel %s.", ch_names[i])
continue
# Make sure that the last detected peak is a positive one
if idx_pos_peaks[-1] < idx_neg_peaks[-1]:
# If not, append a fake positive peak one sample after the last neg
idx_pos_peaks = np.append(idx_pos_peaks, idx_neg_peaks[-1] + 1)
# For each negative peak, we find the closest following positive peak
pk_sorted = np.searchsorted(idx_pos_peaks, idx_neg_peaks)
closest_pos_peaks = idx_pos_peaks[pk_sorted] - idx_neg_peaks
closest_pos_peaks = closest_pos_peaks[np.nonzero(closest_pos_peaks)]
idx_pos_peaks = idx_neg_peaks + closest_pos_peaks
# Now we compute the PTP amplitude and keep only the good peaks
sw_ptp = np.abs(data_filt[i, idx_neg_peaks]) + data_filt[i, idx_pos_peaks]
good_ptp = np.logical_and(sw_ptp > amp_ptp[0], sw_ptp < amp_ptp[1])
# If good_ptp is all False
if all(~good_ptp):
logger.warning("No SW were found in channel %s.", ch_names[i])
continue
sw_ptp = sw_ptp[good_ptp]
idx_neg_peaks = idx_neg_peaks[good_ptp]
idx_pos_peaks = idx_pos_peaks[good_ptp]
# Now we need to check the negative and positive phase duration
# For that we need to compute the zero crossings of the filtered signal
zero_crossings = _zerocrossings(data_filt[i, :])
# Make sure that there is a zero-crossing after the last detected peak
if zero_crossings[-1] < max(idx_pos_peaks[-1], idx_neg_peaks[-1]):
# If not, append the index of the last peak
zero_crossings = np.append(zero_crossings, max(idx_pos_peaks[-1], idx_neg_peaks[-1]))
# Find distance to previous and following zc
neg_sorted = np.searchsorted(zero_crossings, idx_neg_peaks)
previous_neg_zc = zero_crossings[neg_sorted - 1] - idx_neg_peaks
following_neg_zc = zero_crossings[neg_sorted] - idx_neg_peaks
# Distance between the positive peaks and the previous and
# following zero-crossings
pos_sorted = np.searchsorted(zero_crossings, idx_pos_peaks)
previous_pos_zc = zero_crossings[pos_sorted - 1] - idx_pos_peaks
following_pos_zc = zero_crossings[pos_sorted] - idx_pos_peaks
# Duration of the negative and positive phases, in seconds
neg_phase_dur = (np.abs(previous_neg_zc) + following_neg_zc) / sf
pos_phase_dur = (np.abs(previous_pos_zc) + following_pos_zc) / sf
# We now compute a set of metrics
sw_start = times[idx_neg_peaks + previous_neg_zc]
sw_end = times[idx_pos_peaks + following_pos_zc]
# This should be the same as `sw_dur = pos_phase_dur + neg_phase_dur`
# We round to avoid floating point errr (e.g. 1.9000000002)
sw_dur = (sw_end - sw_start).round(4)
sw_dur_both_phase = (pos_phase_dur + neg_phase_dur).round(4)
sw_midcrossing = times[idx_neg_peaks + following_neg_zc]
sw_idx_neg = times[idx_neg_peaks] # Location of negative peak
sw_idx_pos = times[idx_pos_peaks] # Location of positive peak
# Slope between peak trough and midcrossing
sw_slope = sw_ptp / (sw_midcrossing - sw_idx_neg)
# Hypnogram
if hypno is not None:
sw_sta = hypno[idx_neg_peaks]
else:
sw_sta = np.zeros(sw_dur.shape)
# And we apply a set of thresholds to remove bad slow waves
good_sw = np.logical_and.reduce(
(
# Data edges
previous_neg_zc != 0,
following_neg_zc != 0,
previous_pos_zc != 0,
following_pos_zc != 0,
# Duration criteria
sw_dur == sw_dur_both_phase, # dur = negative + positive
sw_dur <= dur_neg[1] + dur_pos[1], # dur < max(neg) + max(pos)
sw_dur >= dur_neg[0] + dur_pos[0], # dur > min(neg) + min(pos)
neg_phase_dur > dur_neg[0],
neg_phase_dur < dur_neg[1],
pos_phase_dur > dur_pos[0],
pos_phase_dur < dur_pos[1],
# Sanity checks
sw_midcrossing > sw_start,
sw_midcrossing < sw_end,
sw_slope > 0,
)
)
if all(~good_sw):
logger.warning("No SW were found in channel %s.", ch_names[i])
continue
# Filter good events
idx_neg_peaks = idx_neg_peaks[good_sw]
idx_pos_peaks = idx_pos_peaks[good_sw]
sw_start = sw_start[good_sw]
sw_idx_neg = sw_idx_neg[good_sw]
sw_midcrossing = sw_midcrossing[good_sw]
sw_idx_pos = sw_idx_pos[good_sw]
sw_end = sw_end[good_sw]
sw_dur = sw_dur[good_sw]
sw_ptp = sw_ptp[good_sw]
sw_slope = sw_slope[good_sw]
sw_sta = sw_sta[good_sw]
# Create a dictionnary
sw_params = OrderedDict(
{
"Start": sw_start,
"NegPeak": sw_idx_neg,
"MidCrossing": sw_midcrossing,
"PosPeak": sw_idx_pos,
"End": sw_end,
"Duration": sw_dur,
"ValNegPeak": data_filt[i, idx_neg_peaks],
"ValPosPeak": data_filt[i, idx_pos_peaks],
"PTP": sw_ptp,
"Slope": sw_slope,
"Frequency": 1 / sw_dur,
"Stage": sw_sta,
}
)
# Add phase (in radians) of slow-oscillation signal at maximum
# spindles-related sigma amplitude within a XX-seconds centered epochs.
if coupling:
# Get phase and amplitude for each centered epoch
time_before = time_after = coupling_params["time"]
assert float(
sf * time_before
).is_integer(), (
"Invalid time parameter for coupling. Must be a whole number of samples."
)
bef = int(sf * time_before)
aft = int(sf * time_after)
# Center of each epoch is defined as the negative peak of the SW
n_peaks = idx_neg_peaks.shape[0]
# idx.shape = (len(idx_valid), bef + aft + 1)
idx, idx_valid = get_centered_indices(data[i, :], idx_neg_peaks, bef, aft)
sw_pha_ev = sw_pha[i, idx]
sp_amp_ev = sp_amp[i, idx]
# 1) Find location of max sigma amplitude in epoch
idx_max_amp = sp_amp_ev.argmax(axis=1)
# Now we need to append it back to the original unmasked shape
# to avoid error when idx.shape[0] != idx_valid.shape, i.e.
# some epochs were out of data bounds.
sw_params["SigmaPeak"] = np.ones(n_peaks) * np.nan
# Timestamp at sigma peak, expressed in seconds from negative peak
# e.g. -0.39, 0.5, 1, 2 -- limits are [time_before, time_after]
time_sigpk = (idx_max_amp - bef) / sf
# convert to absolute time from beginning of the recording
# time_sigpk only includes valid epoch
time_sigpk_abs = sw_idx_neg[idx_valid] + time_sigpk
sw_params["SigmaPeak"][idx_valid] = time_sigpk_abs
# 2) PhaseAtSigmaPeak
# Find SW phase at max sigma amplitude in epoch
pha_at_max = np.squeeze(np.take_along_axis(sw_pha_ev, idx_max_amp[..., None], axis=1))
sw_params["PhaseAtSigmaPeak"] = np.ones(n_peaks) * np.nan
sw_params["PhaseAtSigmaPeak"][idx_valid] = pha_at_max
# 3) Normalized Direct PAC, with thresholding
# Unreliable values are set to 0
ndp = np.squeeze(
tpm.norm_direct_pac(
sw_pha_ev[None, ...], sp_amp_ev[None, ...], p=coupling_params["p"]
)
)
sw_params["ndPAC"] = np.ones(n_peaks) * np.nan
sw_params["ndPAC"][idx_valid] = ndp
# Make sure that Stage is the last column of the dataframe
sw_params.move_to_end("Stage")
# Convert to dataframe, keeping only good events
df_chan = pd.DataFrame(sw_params)
# Remove all duplicates
df_chan = df_chan.drop_duplicates(subset=["Start"], keep=False)
df_chan = df_chan.drop_duplicates(subset=["End"], keep=False)
# We need at least 50 detected slow waves to apply the Isolation Forest
if remove_outliers and df_chan.shape[0] >= 50:
col_keep = ["Duration", "ValNegPeak", "ValPosPeak", "PTP", "Slope", "Frequency"]
ilf = IsolationForest(
contamination="auto", max_samples="auto", verbose=0, random_state=42
)
good = ilf.fit_predict(df_chan[col_keep])
good[good == -1] = 0
logger.info(
"%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
)
# Remove outliers from DataFrame
df_chan = df_chan[good.astype(bool)]
logger.info("%i slow-waves were found in channel %s." % (df_chan.shape[0], ch_names[i]))
# ####################################################################
# END SINGLE CHANNEL DETECTION
# ####################################################################
df_chan["Channel"] = ch_names[i]
df_chan["IdxChannel"] = i
df = pd.concat([df, df_chan], axis=0, ignore_index=True)
# If no SW were detected, return None
if df.empty:
logger.warning("No SW were found in data. Returning None.")
return None
if hypno is None:
df = df.drop(columns=["Stage"])
else:
df["Stage"] = df["Stage"].astype(int)
return SWResults(
events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_filt
)
class SWResults(_DetectionResults):
    """Output class for slow-waves detection.

    Attributes
    ----------
    _events : :py:class:`pandas.DataFrame`
        Output detection dataframe.
    _data : array_like
        EEG data of shape *(n_chan, n_samples)*.
    _data_filt : array_like
        Slow-wave filtered EEG data of shape *(n_chan, n_samples)*.
    _sf : float
        Sampling frequency of data.
    _ch_names : list
        Channel names.
    _hypno : array_like or None
        Sleep staging vector.
    """

    def __init__(self, events, data, sf, ch_names, hypno, data_filt):
        """Store the detection results (storage is handled by the parent class)."""
        super().__init__(events, data, sf, ch_names, hypno, data_filt)

    def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True):
        """Return a summary of the SW detection, optionally grouped across
        channels and/or stage.

        Parameters
        ----------
        grp_chan : bool
            If True, group by channel (for multi-channels detection only).
        grp_stage : bool
            If True, group by sleep stage (provided that an hypnogram was used).
        mask : array_like or None
            Custom boolean mask. Only the detected events for which mask is True will be
            included in the summary. Default is None, i.e. no masking (all events are included).
        aggfunc : str or function
            Averaging function (e.g. ``'mean'`` or ``'median'``).
        sort : bool
            If True, sort group keys when grouping.
        """
        return super().summary(
            event_type="sw",
            grp_chan=grp_chan,
            grp_stage=grp_stage,
            aggfunc=aggfunc,
            sort=sort,
            mask=mask,
        )

    def find_cooccurring_spindles(self, spindles, lookaround=1.2):
        """Given a spindles detection summary dataframe, find slow-waves that co-occur with
        sleep spindles.

        .. versionadded:: 0.6.0

        Parameters
        ----------
        spindles : :py:class:`pandas.DataFrame`
            Output dataframe of :py:meth:`yasa.SpindlesResults.summary`.
        lookaround : float
            Lookaround window, in seconds. The default is +/- 1.2 seconds around the
            negative peak of the slow-wave, as in [1]_. This means that YASA will look for a
            spindle in a 2.4 seconds window centered around the downstate of the slow-wave.

        Returns
        -------
        _events : :py:class:`pandas.DataFrame`
            The slow-wave detection is modified IN-PLACE (see Notes). To see the updated
            dataframe, call the :py:meth:`yasa.SWResults.summary` method.

        Notes
        -----
        From [1]_:

            "SO–spindle co-occurrence was first determined by the number of spindle centers
            occurring within a ±1.2-sec window around the downstate peak of a SO, expressed as
            the ratio of all detected SO events in an individual channel."

        This function adds three columns to the output detection dataframe:

        * `CooccurringSpindle`: a boolean column (True / False) that indicates whether the
          given slow-wave co-occurs with a sleep spindle.
        * `CooccurringSpindlePeak`: the timestamp of the peak of the co-occurring spindle,
          in seconds from beginning of recording. Values are set to np.nan when no
          co-occurring spindles were found.
        * `DistanceSpindleToSW`: The distance in seconds from the center peak of the spindles
          and the negative peak of the slow-waves. Negative values indicate that the spindles
          occurred before the negative peak of the slow-waves. Values are set to np.nan when
          no co-occurring spindles were found.

        References
        ----------
        .. [1] Kurz, E. M., Conzelmann, A., Barth, G. M., Renner, T. J., Zinke, K., & Born, J.
           (2021). How do children with autism spectrum disorder form gist memory during sleep?
           A study of slow oscillation–spindle coupling. Sleep, 44(6), zsaa290.
        """
        assert isinstance(spindles, pd.DataFrame), "spindles must be a detection dataframe."
        distance_sp_to_sw_peak = []
        cooccurring_spindle_peaks = []
        # At least one channel must be present in both detections.
        common_ch = np.intersect1d(self._events["Channel"].unique(), spindles["Channel"].unique())
        assert len(common_ch), "No common channel(s) were found."
        # Loop across ALL channels of the SW detection (not only the common ones) so that
        # the appended lists stay aligned row-by-row with ``self._events``. For a channel
        # with no detected spindles, ``sp_chan_peaks`` is empty and NaN is appended for
        # each of its slow-waves.
        for chan in self._events["Channel"].unique():
            sw_chan_peaks = self._events[self._events["Channel"] == chan]["NegPeak"].to_numpy()
            sp_chan_peaks = spindles[spindles["Channel"] == chan]["Peak"].to_numpy()
            # Loop across individual slow-waves
            for sw_negpeak in sw_chan_peaks:
                start = sw_negpeak - lookaround
                end = sw_negpeak + lookaround
                mask = np.logical_and(start < sp_chan_peaks, sp_chan_peaks < end)
                if any(mask):
                    # If multiple spindles are present, take the last one
                    sp_peak = sp_chan_peaks[mask][-1]
                    cooccurring_spindle_peaks.append(sp_peak)
                    distance_sp_to_sw_peak.append(sp_peak - sw_negpeak)
                else:
                    cooccurring_spindle_peaks.append(np.nan)
                    distance_sp_to_sw_peak.append(np.nan)
        # Add columns to self._events: IN-PLACE MODIFICATION!
        self._events["CooccurringSpindle"] = ~np.isnan(distance_sp_to_sw_peak)
        self._events["CooccurringSpindlePeak"] = cooccurring_spindle_peaks
        self._events["DistanceSpindleToSW"] = distance_sp_to_sw_peak

    def compare_channels(self, score="f1", max_distance_sec=0):
        """
        Compare detected slow-waves across channels.

        This is a wrapper around the :py:func:`yasa.compare_detection` function. Please
        refer to the documentation of this function for more details.

        Parameters
        ----------
        score : str
            The performance metric to compute. Accepted values are "precision", "recall"
            (aka sensitivity) and "f1" (default). The F1-score is the harmonic mean of
            precision and recall, and is usually the preferred metric to evaluate the
            agreement between two channels. All three metrics are bounded by 0 and 1,
            where 1 indicates perfect agreement.
        max_distance_sec : float
            The maximum distance between slow-waves, in seconds, to consider as the same event.

            .. warning:: To reduce computation cost, YASA rounds the start time of each spindle
                to the nearest decisecond (= 100 ms). This means that the lowest possible
                resolution is 100 ms, regardless of the sampling frequency of the data.
                Two slow-waves starting at 500 ms and 540 ms on their respective channels
                will therefore always be considered the same event, even when
                max_distance_sec=0.

        Returns
        -------
        scores : :py:class:`pandas.DataFrame`
            A Pandas DataFrame with the output scores, of shape (n_chan, n_chan).

        Notes
        -----
        Some use cases of this function:

        1. What proportion of slow-waves detected in one channel are also detected on
           another channel (if using ``score="recall"``).
        2. What is the overall agreement in the detected events between channels?
        3. Is the agreement better in channels that are close to one another?
        """
        return super().compare_channels(score, max_distance_sec)

    def compare_detection(self, other, max_distance_sec=0, other_is_groundtruth=True):
        """
        Compare the detected slow-waves against either another YASA detection or against
        custom annotations (e.g. ground-truth human scoring).

        This function is a wrapper around the :py:func:`yasa.compare_detection` function.
        Please refer to the documentation of this function for more details.

        Parameters
        ----------
        other : dataframe or detection results
            This can be either a) the output of another YASA detection, for example if you
            want to test the impact of tweaking some parameters on the detected events or
            b) a pandas DataFrame with custom annotations, obtained by another detection
            method outside of YASA, or with manual labelling. If b), the dataframe must
            contain the "Start" and "Channel" columns, with the start of each event in
            seconds from the beginning of the recording and the channel name, respectively.
            The channel names should match the output of the summary() method.
        max_distance_sec : float
            The maximum distance between slow-waves, in seconds, to consider as the same event.

            .. warning:: To reduce computation cost, YASA rounds the start time of each
                slow-wave to the nearest decisecond (= 100 ms). This means that the lowest
                possible resolution is 100 ms, regardless of the sampling frequency of the
                data.
        other_is_groundtruth : bool
            If True (default), ``other`` will be considered as the ground-truth scoring.
            If False, the current detection will be considered as the ground-truth, and the
            precision and recall scores will be inverted. This parameter has no effect on
            the F1-score.

            .. note:: when ``other`` is the ground-truth (default), the recall score is the
                fraction of events in other that were successfully detected by the current
                detection, and the precision score is the proportion of detected events by
                the current detection that are also present in other.

        Returns
        -------
        scores : :py:class:`pandas.DataFrame`
            A Pandas DataFrame with the channel names as index, and the following columns

            * ``precision``: Precision score, aka positive predictive value
            * ``recall``: Recall score, aka sensitivity
            * ``f1``: F1-score
            * ``n_self``: Number of detected events in ``self`` (current method).
            * ``n_other``: Number of detected events in ``other``.

        Notes
        -----
        Some use cases of this function:

        1. How well does YASA events detection perform against ground-truth human
           annotations?
        2. If I change the threshold(s) of the events detection, do the detected events
           match those obtained with the default parameters?
        3. Which detection thresholds give the highest agreement with the ground-truth
           scoring?
        """
        return super().compare_detection(other, max_distance_sec, other_is_groundtruth)

    def get_coincidence_matrix(self, scaled=True):
        """Return the (scaled) coincidence matrix.

        Parameters
        ----------
        scaled : bool
            If True (default), the coincidence matrix is scaled (see Notes).

        Returns
        -------
        coincidence : pd.DataFrame
            A symmetric matrix with the (scaled) coincidence values.

        Notes
        -----
        Do slow-waves occur at the same time? One way to measure this is to
        calculate the coincidence matrix, which gives, for each pair of
        channel, the number of samples that were marked as a slow-waves in both
        channels. The output is a symmetric matrix, in which the diagonal is
        simply the number of data points that were marked as a slow-waves in
        the channel.

        The coincidence matrix can be scaled (default) by dividing the output
        by the product of the sum of each individual binary mask, as shown in
        the example below. It can then be used to define functional
        networks or quickly find outlier channels.

        Examples
        --------
        Calculate the coincidence of two binary mask:

        >>> import numpy as np
        >>> x = np.array([0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1])
        >>> y = np.array([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1])
        >>> x * y
        array([0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1])
        >>> (x * y).sum()  # Coincidence
        3
        >>> (x * y).sum() / (x.sum() * y.sum())  # Scaled coincidence
        0.12

        References
        ----------
        - https://github.com/Mark-Kramer/Sleep-Networks-2021
        """
        return super().get_coincidence_matrix(scaled=scaled)

    def get_mask(self):
        """Return a boolean array indicating for each sample in data if this
        sample is part of a detected event (True) or not (False).
        """
        return super().get_mask()

    def get_sync_events(
        self,
        center="NegPeak",
        time_before=0.4,
        time_after=0.8,
        filt=(None, None),
        mask=None,
        as_dataframe=True,
    ):
        """
        Return the raw data of each detected event after centering to a specific timepoint.

        Parameters
        ----------
        center : str
            Landmark of the event to synchronize the timing on.
            Default is to use the negative peak of the slow-wave.
        time_before : float
            Time (in seconds) before ``center``.
        time_after : float
            Time (in seconds) after ``center``.
        filt : tuple
            Optional filtering to apply to data. For instance, ``filt=(1, 30)``
            will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
            will apply a 40 Hz lowpass filter. Filtering is done using default
            parameters in the :py:func:`mne.filter.filter_data` function.
        mask : array_like or None
            Custom boolean mask. Only the detected events for which mask is True will be
            included. Default is None, i.e. no masking (all events are included).
        as_dataframe : boolean
            If True (default), returns a long-format pandas dataframe. If False, returns a
            list of numpy arrays. Each element of the list a unique channel, and the shape
            of the numpy arrays within the list is (n_events, n_times).

        Returns
        -------
        df_sync : :py:class:`pandas.DataFrame` or list
            Output long-format dataframe (if ``as_dataframe=True``)::

                'Event' : Event number
                'Time' : Timing of the events (in seconds)
                'Amplitude' : Raw or filtered data for event
                'Channel' : Channel
                'IdxChannel' : Index of channel in data
                'Stage': Sleep stage in which the events occurred (if available)
        """
        return super().get_sync_events(
            center=center,
            time_before=time_before,
            time_after=time_after,
            filt=filt,
            mask=mask,
            as_dataframe=as_dataframe,
        )

    def plot_average(
        self,
        center="NegPeak",
        hue="Channel",
        time_before=0.4,
        time_after=0.8,
        filt=(None, None),
        mask=None,
        figsize=(6, 4.5),
        **kwargs,
    ):
        """
        Plot the average slow-wave.

        Parameters
        ----------
        center : str
            Landmark of the event to synchronize the timing on. The default is to use the
            negative peak of the slow-wave.
        hue : str
            Grouping variable that will produce lines with different colors.
            Can be either 'Channel' or 'Stage'.
        time_before : float
            Time (in seconds) before ``center``.
        time_after : float
            Time (in seconds) after ``center``.
        filt : tuple
            Optional filtering to apply to data. For instance, ``filt=(1, 30)``
            will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
            will apply a 40 Hz lowpass filter. Filtering is done using default
            parameters in the :py:func:`mne.filter.filter_data` function.
        mask : array_like or None
            Custom boolean mask. Only the detected events for which mask is True will be
            plotted. Default is None, i.e. no masking (all events are included).
        figsize : tuple
            Figure size in inches.
        **kwargs : dict
            Optional argument that are passed to :py:func:`seaborn.lineplot`.
        """
        return super().plot_average(
            event_type="sw",
            center=center,
            hue=hue,
            time_before=time_before,
            time_after=time_after,
            filt=filt,
            mask=mask,
            figsize=figsize,
            **kwargs,
        )

    def plot_detection(self):
        """Plot an overlay of the detected slow-waves on the EEG signal.

        This only works in Jupyter and it requires the ipywidgets
        (https://ipywidgets.readthedocs.io/en/latest/) package.

        To activate the interactive mode, make sure to run:

        >>> %matplotlib widget

        .. versionadded:: 0.4.0
        """
        return super().plot_detection()
#############################################################################
# REMs DETECTION
#############################################################################
def rem_detect(
    loc,
    roc,
    sf,
    hypno=None,
    include=4,
    amplitude=(50, 325),
    duration=(0.3, 1.2),
    relative_prominence=0.8,
    freq_rem=(0.5, 5),
    remove_outliers=False,
    verbose=False,
):
    """Rapid eye movements (REMs) detection.

    This detection requires both the left EOG (LOC) and right EOG (ROC).
    The units of the data must be uV. The algorithm is based on an amplitude
    thresholding of the negative product of the LOC and ROC
    filtered signal.

    .. versionadded:: 0.1.5

    Parameters
    ----------
    loc, roc : array_like
        Continuous EOG data (Left and Right Ocular Canthi, LOC / ROC) channels.
        Unit must be uV.

        .. warning::
            The default unit of :py:class:`mne.io.BaseRaw` is Volts.
            Therefore, if passing data from a :py:class:`mne.io.BaseRaw`,
            make sure to use units="uV" to get the data in micro-Volts, e.g.:

            >>> data = raw.get_data(units="uV")  # Make sure that data is in uV
    sf : float
        Sampling frequency of the data, in Hz.
    hypno : array_like
        Sleep stage (hypnogram). If the hypnogram is loaded, the
        detection will only be applied to the value defined in
        ``include`` (default = REM sleep).

        The hypnogram must have the same number of samples as ``data``.
        To upsample your hypnogram, please refer to
        :py:func:`yasa.hypno_upsample_to_data`.

        .. note::
            The default hypnogram format in YASA is a 1D integer
            vector where:

            - -2 = Unscored
            - -1 = Artefact / Movement
            - 0 = Wake
            - 1 = N1 sleep
            - 2 = N2 sleep
            - 3 = N3 sleep
            - 4 = REM sleep
    include : tuple, list or int
        Values in ``hypno`` that will be included in the mask. The default is
        (4), meaning that the detection is applied on REM sleep.
        This has no effect when ``hypno`` is None.
    amplitude : tuple or list
        Minimum and maximum amplitude of the peak of the REM.
        Default is 50 uV to 325 uV.
    duration : tuple or list
        The minimum and maximum duration of the REMs.
        Default is 0.3 to 1.2 seconds.
    relative_prominence : float
        Relative prominence used to detect the peaks. The actual prominence is computed
        by multiplying relative prominence by the minimal amplitude. Default is 0.8.
    freq_rem : tuple or list
        Frequency range of REMs. Default is 0.5 to 5 Hz.
    remove_outliers : boolean
        If True, YASA will automatically detect and remove outliers REMs
        using :py:class:`sklearn.ensemble.IsolationForest`.
        YASA uses a random seed (42) to ensure reproducible results.
        Note that this step will only be applied if there are more than
        50 detected REMs in the first place. Default to False.
    verbose : bool or str
        Verbose level. Default (False) will only print warning and error
        messages. The logging levels are 'debug', 'info', 'warning', 'error',
        and 'critical'. For most users the choice is between 'info'
        (or ``verbose=True``) and warning (``verbose=False``).

        .. versionadded:: 0.2.0

    Returns
    -------
    rem : :py:class:`yasa.REMResults`
        To get the full detection dataframe, use:

        >>> rem = rem_detect(...)
        >>> rem.summary()

        This will give a :py:class:`pandas.DataFrame` where each row is a
        detected REM and each column is a parameter (= property).
        To get the average parameters sleep stage:

        >>> rem.summary(grp_stage=True)

    Notes
    -----
    The parameters that are calculated for each REM are:

    * ``'Start'``: Start of each detected REM, in seconds from the
      beginning of data.
    * ``'Peak'``: Location of the peak (in seconds of data)
    * ``'End'``: End time (in seconds)
    * ``'Duration'``: Duration (in seconds)
    * ``'LOCAbsValPeak'``: LOC absolute amplitude at REM peak (in uV)
    * ``'ROCAbsValPeak'``: ROC absolute amplitude at REM peak (in uV)
    * ``'LOCAbsRiseSlope'``: LOC absolute rise slope (in uV/s)
    * ``'ROCAbsRiseSlope'``: ROC absolute rise slope (in uV/s)
    * ``'LOCAbsFallSlope'``: LOC absolute fall slope (in uV/s)
    * ``'ROCAbsFallSlope'``: ROC absolute fall slope (in uV/s)
    * ``'Stage'``: Sleep stage (only if hypno was provided)

    Note that all the output parameters are computed on the filtered LOC and
    ROC signals.

    For better results, apply this detection only on artefact-free REM sleep.

    References
    ----------
    The rapid eye movements detection algorithm is based on:

    * Agarwal, R., Takeuchi, T., Laroche, S., & Gotman, J. (2005).
      `Detection of rapid-eye movements in sleep studies.
      <https://doi.org/10.1109/TBME.2005.851512>`_
      IEEE Transactions on Bio-Medical Engineering, 52(8), 1390–1396.
    * Yetton, B. D., Niknazar, M., Duggan, K. A., McDevitt, E. A., Whitehurst,
      L. N., Sattari, N., & Mednick, S. C. (2016). `Automatic detection of
      rapid eye movements (REMs): A machine learning approach.
      <https://doi.org/10.1016/j.jneumeth.2015.11.015>`_
      Journal of Neuroscience Methods, 259, 72–82.

    Examples
    --------
    For an example of how to run the detection, please refer to
    https://github.com/raphaelvallat/yasa/blob/master/notebooks/07_REMs_detection.ipynb
    """
    set_log_level(verbose)
    # Safety checks
    loc = np.squeeze(np.asarray(loc, dtype=np.float64))
    roc = np.squeeze(np.asarray(roc, dtype=np.float64))
    assert loc.ndim == 1, "LOC must be 1D."
    assert roc.ndim == 1, "ROC must be 1D."
    assert loc.size == roc.size, "LOC and ROC must have the same size."
    data = np.vstack((loc, roc))
    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
        data, sf, ["LOC", "ROC"], hypno, include
    )
    # The detection needs both EOG channels: abort if either one has bad amplitude.
    if any(bad_chan):
        logger.warning("At least one channel has bad amplitude. Returning None.")
        return None
    # Bandpass filter
    data_filt = filter_data(data, sf, freq_rem[0], freq_rem[1], verbose=0)
    # Calculate the negative product of LOC and ROC, maximal during REM.
    negp = -data_filt[0, :] * data_filt[1, :]
    # Find peaks in data
    # - height: required height of peaks (min and max.)
    # - distance: required distance in samples between neighboring peaks.
    # - prominence: required prominence of peaks.
    # - wlen: limit search for bases to a specific window.
    hmin, hmax = amplitude[0] ** 2, amplitude[1] ** 2
    pks, pks_params = signal.find_peaks(
        negp,
        height=(hmin, hmax),
        distance=(duration[0] * sf),
        prominence=(relative_prominence * hmin),
        wlen=(duration[1] * sf),
    )
    # Intersect with sleep stage vector
    # We do that before calculating the features in order to gain some time
    idx_mask = np.where(mask)[0]
    pks, idx_good, _ = np.intersect1d(pks, idx_mask, assume_unique=True, return_indices=True)
    for k in pks_params.keys():
        pks_params[k] = pks_params[k][idx_good]
    # If no peaks are detected, return None
    if len(pks) == 0:
        logger.warning("No REMs were found in data. Returning None.")
        return None
    # Hypnogram
    if hypno is not None:
        # The sleep stage at the beginning of the REM is considered.
        rem_sta = hypno[pks_params["left_bases"]]
    else:
        rem_sta = np.zeros(pks.shape)
    # Calculate time features
    pks_params["Start"] = pks_params["left_bases"] / sf
    pks_params["Peak"] = pks / sf
    pks_params["End"] = pks_params["right_bases"] / sf
    pks_params["Duration"] = pks_params["End"] - pks_params["Start"]
    # Absolute LOC / ROC value at peak (filtered)
    pks_params["LOCAbsValPeak"] = abs(data_filt[0, pks])
    pks_params["ROCAbsValPeak"] = abs(data_filt[1, pks])
    # Absolute rising and falling slope
    dist_pk_left = (pks - pks_params["left_bases"]) / sf
    dist_pk_right = (pks_params["right_bases"] - pks) / sf
    locrs = (data_filt[0, pks] - data_filt[0, pks_params["left_bases"]]) / dist_pk_left
    rocrs = (data_filt[1, pks] - data_filt[1, pks_params["left_bases"]]) / dist_pk_left
    locfs = (data_filt[0, pks_params["right_bases"]] - data_filt[0, pks]) / dist_pk_right
    rocfs = (data_filt[1, pks_params["right_bases"]] - data_filt[1, pks]) / dist_pk_right
    pks_params["LOCAbsRiseSlope"] = abs(locrs)
    pks_params["ROCAbsRiseSlope"] = abs(rocrs)
    pks_params["LOCAbsFallSlope"] = abs(locfs)
    pks_params["ROCAbsFallSlope"] = abs(rocfs)
    pks_params["Stage"] = rem_sta  # Sleep stage
    # Convert to Pandas DataFrame
    df = pd.DataFrame(pks_params)
    # Make sure that the sign of ROC and LOC is opposite
    opposite_sign = np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks])
    df["IsOppositeSign"] = opposite_sign
    # Remove bad duration
    tmin, tmax = duration
    good_dur = np.logical_and(pks_params["Duration"] >= tmin, pks_params["Duration"] < tmax)
    # BUGFIX: apply both criteria with a single full-length mask. The previous code
    # first filtered ``df`` by the opposite-sign criterion and then indexed the
    # already-reduced dataframe with the full-length ``good_dur`` array, which raises
    # a length-mismatch error whenever the sign criterion removed at least one row.
    df = df[np.logical_and(opposite_sign, good_dur)]
    # Keep only useful columns
    df = df[
        [
            "Start",
            "Peak",
            "End",
            "Duration",
            "LOCAbsValPeak",
            "ROCAbsValPeak",
            "LOCAbsRiseSlope",
            "ROCAbsRiseSlope",
            "LOCAbsFallSlope",
            "ROCAbsFallSlope",
            "Stage",
        ]
    ]
    if hypno is None:
        df = df.drop(columns=["Stage"])
    else:
        df["Stage"] = df["Stage"].astype(int)
    # We need at least 50 detected REMs to apply the Isolation Forest.
    if remove_outliers and df.shape[0] >= 50:
        col_keep = [
            "Duration",
            "LOCAbsValPeak",
            "ROCAbsValPeak",
            "LOCAbsRiseSlope",
            "ROCAbsRiseSlope",
            "LOCAbsFallSlope",
            "ROCAbsFallSlope",
        ]
        ilf = IsolationForest(contamination="auto", max_samples="auto", verbose=0, random_state=42)
        good = ilf.fit_predict(df[col_keep])
        good[good == -1] = 0
        logger.info("%i outliers were removed.", (good == 0).sum())
        # Remove outliers from DataFrame
        df = df[good.astype(bool)]
    logger.info("%i REMs were found in data.", df.shape[0])
    df = df.reset_index(drop=True)
    return REMResults(
        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_filt
    )
[docs]class REMResults(_DetectionResults):
"""Output class for REMs detection.
Attributes
----------
_events : :py:class:`pandas.DataFrame`
Output detection dataframe
_data : array_like
EOG data of shape *(n_chan, n_samples)*, where the two channels are
LOC and ROC.
_data_filt : array_like
Filtered EOG data of shape *(n_chan, n_samples)*, where the two
channels are LOC and ROC.
_sf : float
Sampling frequency of data.
_ch_names : list
Channel names (= ``['LOC', 'ROC']``)
_hypno : array_like or None
Sleep staging vector.
"""
def __init__(self, events, data, sf, ch_names, hypno, data_filt):
    """Store the REM detection results (storage is handled by the parent class)."""
    super().__init__(events, data, sf, ch_names, hypno, data_filt)
def summary(self, grp_stage=False, mask=None, aggfunc="mean", sort=True):
    """Return a summary of the REM detection, optionally grouped across stage.

    Parameters
    ----------
    grp_stage : bool
        If True, group by sleep stage (provided that an hypnogram was
        used).
    mask : array_like or None
        Custom boolean mask. Only the detected events for which mask is True will be
        included in the summary. Default is None, i.e. no masking (all events are included).
    aggfunc : str or function
        Averaging function (e.g. ``'mean'`` or ``'median'``).
    sort : bool
        If True, sort group keys when grouping.
    """
    # ``grp_chan`` is always False for REM detection because the
    # REMs are always detected on a combination of LOC and ROC.
    return super().summary(
        event_type="rem",
        grp_chan=False,
        grp_stage=grp_stage,
        aggfunc=aggfunc,
        sort=sort,
        mask=mask,
    )
def get_mask(self):
    """Return a boolean array indicating for each sample in data if this
    sample is part of a detected event (True) or not (False).
    """
    # We cannot use super() because "Channel" is not present in _events.
    from yasa.others import _index_to_events

    # NOTE(review): the returned array holds int 0/1 values, not np.bool_,
    # although the docstring says "boolean" — kept as-is for backward compatibility.
    mask = np.zeros(self._data.shape, dtype=int)
    idx_ev = _index_to_events(self._events[["Start", "End"]].to_numpy() * self._sf)
    mask[:, idx_ev] = 1
    return mask
[docs] def get_sync_events(
self, center="Peak", time_before=0.4, time_after=0.4, filt=(None, None), mask=None
):
"""
Return the raw or filtered data of each detected event after centering to a specific
timepoint.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the peak of the REM.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
mask : array_like or None
Custom boolean mask. Only the detected events for which mask is True will be
included. Default is None, i.e. no masking (all events are included).
Returns
-------
df_sync : :py:class:`pandas.DataFrame`
Ouput long-format dataframe::
'Event' : Event number
'Time' : Timing of the events (in seconds)
'Amplitude' : Raw or filtered data for event
'Channel' : Channel
'IdxChannel' : Index of channel in data
"""
from yasa.others import get_centered_indices
assert time_before >= 0
assert time_after >= 0
bef = int(self._sf * time_before)
aft = int(self._sf * time_after)
if any(filt):
data = mne.filter.filter_data(
self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method="fir", verbose=False
)
else:
data = self._data
# Apply mask
mask = self._check_mask(mask)
masked_events = self._events.loc[mask, :]
time = np.arange(-bef, aft + 1, dtype="int") / self._sf
# Get location of peaks in data
peaks = (masked_events[center] * self._sf).astype(int).to_numpy()
# Get centered indices (here we could use second channel as well).
idx, idx_valid = get_centered_indices(data[0, :], peaks, bef, aft)
# If no good epochs are returned raise a warning
assert len(idx_valid), (
"Time before and/or time after exceed data bounds, please "
"lower the temporal window around center."
)
# Initialize empty dataframe
df_sync = pd.DataFrame()
# Loop across both EOGs (LOC and ROC)
for i, ch in enumerate(self._ch_names):
amps = data[i, idx]
df_chan = pd.DataFrame(amps.T)
df_chan["Time"] = time
df_chan = df_chan.melt(id_vars="Time", var_name="Event", value_name="Amplitude")
df_chan["Channel"] = ch
df_chan["IdxChannel"] = i
df_sync = pd.concat([df_sync, df_chan], axis=0, ignore_index=True)
return df_sync
[docs] def plot_average(
self,
center="Peak",
time_before=0.4,
time_after=0.4,
filt=(None, None),
mask=None,
figsize=(6, 4.5),
**kwargs,
):
"""
Plot the average REM.
Parameters
----------
center : str
Landmark of the event to synchronize the timing on.
Default is to use the peak of the REM.
time_before : float
Time (in seconds) before ``center``.
time_after : float
Time (in seconds) after ``center``.
filt : tuple
Optional filtering to apply to data. For instance, ``filt=(1, 30)``
will apply a 1 to 30 Hz bandpass filter, and ``filt=(None, 40)``
will apply a 40 Hz lowpass filter. Filtering is done using default
parameters in the :py:func:`mne.filter.filter_data` function.
mask : array_like or None
Custom boolean mask. Only the detected events for which mask is True will be
included. Default is None, i.e. no masking (all events are included).
figsize : tuple
Figure size in inches.
**kwargs : dict
Optional argument that are passed to :py:func:`seaborn.lineplot`.
"""
import seaborn as sns
import matplotlib.pyplot as plt
df_sync = self.get_sync_events(
center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask
)
# Start figure
fig, ax = plt.subplots(1, 1, figsize=figsize)
sns.lineplot(data=df_sync, x="Time", y="Amplitude", hue="Channel", ax=ax, **kwargs)
# ax.legend(frameon=False, loc='lower right')
ax.set_xlim(df_sync["Time"].min(), df_sync["Time"].max())
ax.set_title("Average REM")
ax.set_xlabel("Time (sec)")
ax.set_ylabel("Amplitude (uV)")
return ax
#############################################################################
# ARTEFACT DETECTION
#############################################################################
def art_detect(
    data,
    sf=None,
    window=5,
    hypno=None,
    include=(1, 2, 3, 4),
    method="covar",
    threshold=3,
    n_chan_reject=1,
    verbose=False,
):
    r"""
    Automatic artifact rejection.

    .. versionadded:: 0.2.0

    Parameters
    ----------
    data : array_like
        Single or multi-channel EEG data.
        Unit must be uV and shape *(n_chan, n_samples)*.
        Can also be a :py:class:`mne.io.BaseRaw`, in which case ``data``
        and ``sf`` will be automatically extracted,
        and ``data`` will also be automatically converted from Volts (MNE)
        to micro-Volts (YASA).

        .. warning::
            ``data`` must only contains EEG channels. Please make sure to
            exclude any EOG, EKG or EMG channels.
    sf : float
        Sampling frequency of the data in Hz.
        Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw` object.
    window : float
        The window length (= resolution) for artifact rejection, in seconds.
        Default to 5 seconds. Shorter windows (e.g. 1 or 2-seconds) will
        drastically increase computation time when ``method='covar'``.
    hypno : array_like
        Sleep stage (hypnogram). If the hypnogram is passed, the
        detection will be applied separately for each of the stages defined in
        ``include``.

        The hypnogram must have the same number of samples as ``data``.
        To upsample your hypnogram, please refer to
        :py:func:`yasa.hypno_upsample_to_data`.

        .. note::
            The default hypnogram format in YASA is a 1D integer
            vector where:

            - -2 = Unscored
            - -1 = Artefact / Movement
            - 0 = Wake
            - 1 = N1 sleep
            - 2 = N2 sleep
            - 3 = N3 sleep
            - 4 = REM sleep
    include : tuple, list or int
        Sleep stages in ``hypno`` on which to perform the artifact rejection.
        The default is ``include=(1, 2, 3, 4)``, meaning that the artifact
        rejection is applied separately for all sleep stages, excluding wake.
        This parameter has no effect when ``hypno`` is None.
    method : str
        Artifact detection method (see Notes):

        * ``'covar'`` : Covariance-based, default for 4+ channels data
        * ``'std'`` : Standard-deviation-based, default for single-channel data
    threshold : float
        The number of standard deviations above or below which an
        epoch is considered an artifact. Higher values will result in a more
        conservative detection, i.e. less rejected epochs.
    n_chan_reject : int
        The number of channels that must be below or above ``threshold`` on any
        given epochs to consider this epoch as an artefact when
        ``method='std'``. The default is 1, which means that the epoch will
        be marked as artifact as soon as one channel is above or below the
        threshold. This may be too conservative when working with a large
        number of channels (e.g. hdEEG) in which case users can increase
        ``n_chan_reject``. Note that this parameter only has an effect
        when ``method='std'``.
    verbose : bool or str
        Verbose level. Default (False) will only print warning and error
        messages. The logging levels are 'debug', 'info', 'warning', 'error',
        and 'critical'. For most users the choice is between 'info'
        (or ``verbose=True``) and warning (``verbose=False``).

        .. versionadded:: 0.2.0

    Returns
    -------
    art_epochs : array_like
        1-D array of shape *(n_epochs)* where 1 = Artefact and 0 = Good.
    zscores : array_like
        Array of z-scores, shape is *(n_epochs)* if ``method='covar'`` and
        *(n_epochs, n_chan)* if ``method='std'``.

    Notes
    -----
    .. caution::
        This function will only detect major body artefacts present on the EEG
        channel. It will not detect EKG contamination or eye blinks. For more
        artifact rejection tools, please refer to the `MNE Python package
        <https://mne.tools/stable/auto_tutorials/preprocessing/plot_10_preprocessing_overview.html>`_.

    .. tip::
        For best performance, apply this function on pre-staged data and make
        sure to pass the hypnogram.
        Sleep stages have very different EEG signatures
        and the artifact rejection will be much more accurate when applied
        separately on each sleep stage.

    We provide below a short description of the different methods. For
    multi-channel data, and if computation time is not an issue, we recommend
    using ``method='covar'`` which uses a clustering approach on
    variance-covariance matrices, and therefore takes into account
    not only the variance in each channel and each epoch, but also the
    inter-relationship (covariance) between channels.
    ``method='covar'`` is however not supported for single-channel EEG or when
    less than 4 channels are present in ``data``. In these cases, one can
    use the much faster ``method='std'`` which is simply based on a z-scoring
    of the log-transformed standard deviation of each channel and each epoch.

    **1/ Covariance-based multi-channel artefact rejection**

    ``method='covar'`` is essentially a wrapper around the
    :py:class:`pyriemann.clustering.Potato` class implemented in the
    `pyRiemann package
    <https://pyriemann.readthedocs.io/en/latest/index.html>`_.
    The main idea of this approach is to estimate a reference covariance
    matrix :math:`\bar{C}` (for each sleep stage separately if ``hypno`` is
    present) and reject every epoch which is too far from this reference
    matrix.
    The distance of the covariance matrix of the current epoch :math:`C`
    from the reference matrix is calculated using Riemannian
    geometry, which is more adapted than Euclidean geometry for
    symmetric positive definite covariance matrices:

    .. math:: d = {\left( \sum_i \log(\lambda_i)^2 \right)}^{1/2}

    where :math:`\lambda_i` are the joint eigenvalues of :math:`C` and
    :math:`\bar{C}`. The epoch with covariance matrix :math:`C`
    will be marked as an artifact if the distance :math:`d`
    is greater than a threshold :math:`T`
    (typically 2 or 3 standard deviations).
    :math:`\bar{C}` is iteratively estimated using a clustering approach.

    **2/ Standard-deviation-based single and multi-channel artefact rejection**

    ``method='std'`` is a much faster and straightforward approach which
    is simply based on the distribution of the standard deviations of each
    epoch. Specifically, one first calculates the standard
    deviations of each epoch and each channel. Then, the resulting array of
    standard deviations is log-transformed and z-scored (for each sleep
    stage separately if ``hypno`` is present). Any epoch with one or more
    channel exceeding the threshold will be marked as artifact.

    Note that this approach is more sensitive to noise and/or the influence of
    one bad channel (e.g. electrode fell off at some point during the night).
    We therefore recommend that you visually inspect and remove any bad
    channels prior to using this function.

    References
    ----------
    * Barachant, A., Andreev, A., & Congedo, M. (2013). `The Riemannian
      Potato: an automatic and adaptive artifact detection method for online
      experiments using Riemannian geometry.
      <https://hal.archives-ouvertes.fr/hal-00781701/>`_ TOBI
      Workshop lV, 19–20.

    * Barthélemy, Q., Mayaud, L., Ojeda, D., & Congedo, M. (2019).
      `The Riemannian Potato Field: A Tool for Online Signal Quality Index of
      EEG. <https://doi.org/10.1109/TNSRE.2019.2893113>`_
      IEEE Transactions on Neural Systems and Rehabilitation Engineering:
      A Publication of the IEEE Engineering in Medicine and Biology Society,
      27(2), 244–255.

    * https://pyriemann.readthedocs.io/en/latest/index.html

    Examples
    --------
    For an example of how to run the detection, please refer to
    https://github.com/raphaelvallat/yasa/blob/master/notebooks/13_artifact_rejection.ipynb
    """
    ###########################################################################
    # PREPROCESSING
    ###########################################################################
    set_log_level(verbose)

    (data, sf, _, hypno, include, _, n_chan, n_samples, _) = _check_data_hypno(
        data, sf, ch_names=None, hypno=hypno, include=include, check_amp=False
    )

    assert isinstance(n_chan_reject, int), "n_chan_reject must be int."
    assert n_chan_reject >= 1, "n_chan_reject must be >= 1."
    assert n_chan_reject <= n_chan, "n_chan_reject must be <= n_chan."

    # Safety check: sampling frequency and window
    assert isinstance(sf, (int, float)), "sf must be int or float"
    assert isinstance(window, (int, float)), "window must be int or float"
    if isinstance(sf, float):
        assert sf.is_integer(), "sf must be a whole number."
        sf = int(sf)
    win_sec = window
    window = win_sec * sf  # Convert window to samples
    if isinstance(window, float):
        assert window.is_integer(), "window * sf must be a whole number."
        window = int(window)

    # Safety check: hypnogram
    if hypno is not None:
        # Extract hypnogram with only complete epochs
        idx_max_full_epoch = int(np.floor(n_samples / window))
        hypno_win = hypno[::window][:idx_max_full_epoch]

    # Safety checks: methods
    assert isinstance(method, str), "method must be a string."
    method = method.lower()
    if method in ["cov", "covar", "covariance", "riemann", "potato"]:
        method = "covar"
        is_pyriemann_installed()
        from pyriemann.estimation import Covariances, Shrinkage
        from pyriemann.clustering import Potato

        # The covariance-based approach needs more than 4 channels to produce
        # meaningful covariance matrices; otherwise fall back to 'std'.
        if n_chan <= 4:
            logger.warning(
                "Must have at least 4 channels for method='covar'. "
                "Automatically switching to method='std'."
            )
            method = "std"

    ###########################################################################
    # START THE REJECTION
    ###########################################################################
    # Remove flat channels
    isflat = np.nanstd(data, axis=-1) == 0
    if isflat.any():
        logger.warning("Flat channel(s) were found and removed in data.")
        data = data[~isflat]
        n_chan = data.shape[0]

    # Epoch the data (n_epochs, n_chan, n_samples)
    _, epochs = sliding_window(data, sf, window=win_sec)
    n_epochs = epochs.shape[0]

    # We first need to identify epochs with flat data (n_epochs, n_chan)
    isflat = (epochs == epochs[:, :, 1][..., None]).all(axis=-1)
    # 1 when all channels are flat, 0 when none are flat (n_epochs)
    prop_chan_flat = isflat.sum(axis=-1) / n_chan
    # If >= 50% of channels are flat, automatically mark as artefact
    epoch_is_flat = prop_chan_flat >= 0.5
    where_flat_epochs = np.nonzero(epoch_is_flat)[0]
    n_flat_epochs = where_flat_epochs.size

    # Now let's make sure that we have an hypnogram and an include variable.
    # ``hypno_win`` was only created above when a hypnogram was passed.
    if hypno is None:
        # [-2, -2, -2, -2, ...], where -2 stands for unscored
        hypno_win = -2 * np.ones(n_epochs, dtype="float")
        include = np.array([-2], dtype="float")

    # We want to make sure that hypno_win and n_epochs have EXACTLY same shape
    assert n_epochs == hypno_win.shape[-1], "Hypno and epochs do not match."

    # Finally, we make sure not to include any flat epochs in calculation,
    # using a sentinel stage value that is unlikely to be picked by users.
    if n_flat_epochs > 0:
        hypno_win[where_flat_epochs] = -111991

    # Add logger info (lazy %-style args avoid formatting when logging is off)
    logger.info("Number of channels in data = %i", n_chan)
    logger.info("Number of samples in data = %i", n_samples)
    logger.info("Sampling frequency = %.2f Hz", sf)
    logger.info("Data duration = %.2f seconds", n_samples / sf)
    logger.info("Number of epochs = %i", n_epochs)
    logger.info("Artifact window = %.2f seconds", win_sec)
    logger.info("Method = %s", method)
    logger.info("Threshold = %.2f standard deviations", threshold)

    def _enough_epochs(stage, n_stage_epochs):
        # At least 30 epochs are required to calculate z-scores, which
        # amounts to 2.5 minutes when using a 5-seconds window.
        if n_stage_epochs < 30:
            if hypno is not None:
                # Only show warning if user actually passed a hypnogram
                logger.warning(
                    f"At least 30 epochs are required to "
                    f"calculate z-score. Skipping "
                    f"stage {stage}"
                )
            return False
        return True

    def _log_stage_rejection(stage, art):
        # Per-stage rejection summary, only shown if user passed a hypnogram
        if hypno is not None:
            perc_reject = 100 * (art.sum() / art.size)
            logger.info(
                f"Stage {stage}: {art.sum()} / {art.size} "
                f"epochs rejected ({perc_reject:.2f}%)"
            )

    # Create empty `epoch_is_art` vector (1 sample = 1 epoch)
    epoch_is_art = np.zeros(n_epochs, dtype="int")

    if method == "covar":
        # Calculate the covariance matrices,
        # shape (n_epochs, n_chan, n_chan)
        covmats = Covariances().fit_transform(epochs)
        # Shrink the covariance matrix (ensure positive semi-definite)
        covmats = Shrinkage().fit_transform(covmats)
        # Define Potato instance: 0 = clean, 1 = art.
        # NOTE: the maximum number of iterations is capped at 10 (instead of
        # pyRiemann's default of 100) to keep the computation time reasonable.
        potato = Potato(
            metric="riemann", threshold=threshold, pos_label=0, neg_label=1, n_iter_max=10
        )
        # Create empty z-scores output (n_epochs)
        zscores = np.full(n_epochs, np.nan, dtype="float")

        for stage in include:
            where_stage = np.where(hypno_win == stage)[0]
            if not _enough_epochs(stage, where_stage.size):
                continue
            # Apply Potato algorithm, extract z-scores and labels
            zs = potato.fit_transform(covmats[where_stage])
            art = potato.predict(covmats[where_stage]).astype(int)
            _log_stage_rejection(stage, art)
            # Append to global vector
            epoch_is_art[where_stage] = art
            zscores[where_stage] = zs

    elif method in ["std", "sd"]:
        # Calculate log-transformed standard dev in each epoch
        # We add 1 to avoid a log warning if std is zero (e.g. flat line)
        # (n_epochs, n_chan)
        std_epochs = np.log(np.nanstd(epochs, axis=-1) + 1)
        # Create empty zscores output (n_epochs, n_chan)
        zscores = np.full((n_epochs, n_chan), np.nan, dtype="float")

        for stage in include:
            where_stage = np.where(hypno_win == stage)[0]
            if not _enough_epochs(stage, where_stage.size):
                continue
            # Calculate z-scores of STD for each channel x stage
            c_mean = np.nanmean(std_epochs[where_stage], axis=0, keepdims=True)
            c_std = np.nanstd(std_epochs[where_stage], axis=0, keepdims=True)
            zs = (std_epochs[where_stage] - c_mean) / c_std
            # Any epoch with at least ``n_chan_reject`` channels above or
            # below the threshold is marked as an artifact.
            n_chan_supra = (np.abs(zs) > threshold).sum(axis=1)  # >
            art = (n_chan_supra >= n_chan_reject).astype(int)  # >= !
            _log_stage_rejection(stage, art)
            # Append to global vector
            epoch_is_art[where_stage] = art
            zscores[where_stage, :] = zs

    else:
        # Fail loudly on an unknown method instead of a NameError later on.
        raise ValueError(f"Invalid artifact detection method: {method}")

    # Mark flat epochs as artefacts
    if n_flat_epochs > 0:
        logger.info(
            f"Rejecting {n_flat_epochs} epochs with >=50% of channels "
            f"that are flat. Z-scores set to np.nan for these epochs."
        )
        epoch_is_art[where_flat_epochs] = 1

    # Log total percentage of epochs rejected
    perc_reject = 100 * (epoch_is_art.sum() / n_epochs)
    logger.info(
        f"TOTAL: {epoch_is_art.sum()} / {n_epochs} epochs rejected ({perc_reject:.2f}%)"
    )

    # Convert epoch_is_art to boolean [0, 0, 1] --> [False, False, True]
    epoch_is_art = epoch_is_art.astype(bool)
    return epoch_is_art, zscores
#############################################################################
# COMPARE DETECTION
#############################################################################
def compare_detection(indices_detection, indices_groundtruth, max_distance=0):
    """
    Determine correctness of detected events against ground-truth events.

    Parameters
    ----------
    indices_detection : array_like
        Indices of the detected events. For example, this could be the indices of the
        start of the spindles, or the negative peak of the slow-waves. The indices must be in
        samples, and not in seconds.
    indices_groundtruth : array_like
        Indices of the ground-truth events, in samples.
    max_distance : int, optional
        Maximum distance between indices, in samples, to consider as the same event (default = 0).
        For example, if the sampling frequency of the data is 100 Hz, using `max_distance=100` will
        search for a matching event 1 second before or after the current event.

    Returns
    -------
    results : dict
        A dictionary with the comparison results:

        * ``tp``: True positives, i.e. actual events detected as events.
        * ``fp``: False positives, i.e. non-events detected as events.
        * ``fn``: False negatives, i.e. actual events not detected as events.
        * ``precision``: Precision score, aka positive predictive value (see Notes)
        * ``recall``: Recall score, aka sensitivity (see Notes)
        * ``f1``: F1-score (see Notes)

    Notes
    -----
    * The precision score is calculated as TP / (TP + FP).
    * The recall score is calculated as TP / (TP + FN).
    * The F1-score is calculated as TP / (TP + 0.5 * (FP + FN)).

    This function is inspired by the `sleepecg.compare_heartbeats
    <https://sleepecg.readthedocs.io/en/stable/generated/sleepecg.compare_heartbeats.html>`_
    function.

    Examples
    --------
    A simple example. Here, `detected` refers to the indices (in the data) of the detected events.
    These could be for example the index of the onset of each detected spindle. `grndtrth` refers
    to the ground-truth (e.g. human-annotated) events.

    >>> from yasa import compare_detection
    >>> detected = [5, 12, 20, 34, 41, 57, 63]
    >>> grndtrth = [5, 12, 18, 26, 34, 41, 55, 63, 68]
    >>> compare_detection(detected, grndtrth)
    {'tp': array([ 5, 12, 34, 41, 63]),
     'fp': array([20, 57]),
     'fn': array([18, 26, 55, 68]),
     'precision': 0.7142857142857143,
     'recall': 0.5555555555555556,
     'f1': 0.625}

    There are 5 true positives, 2 false positives and 4 false negatives. This gives a precision
    score of 0.71 (= 5 / (5 + 2)), a recall score of 0.55 (= 5 / (5 + 4)) and a F1-score of 0.625.
    The F1-score is the harmonic average of precision and recall, and should be the preferred
    metric when comparing the performance of a detection against a ground-truth.

    Order matters! If we set `detected` as the ground-truth, FP and FN are inverted, and same for
    precision and recall. The TP and F1-score remain the same though. Therefore, when comparing two
    detections (and not a detection against a ground-truth), the F1-score is the preferred metric
    because it is independent of the order.

    >>> compare_detection(grndtrth, detected)
    {'tp': array([ 5, 12, 34, 41, 63]),
     'fp': array([18, 26, 55, 68]),
     'fn': array([20, 57]),
     'precision': 0.5555555555555556,
     'recall': 0.7142857142857143,
     'f1': 0.625}

    There might be some events that are very close to each other, and we would like to count them
    as true positive even though they do not occur exactly at the same index. This is possible
    with the `max_distance` argument, which defines the lookaround window (in samples) for
    each event.

    >>> compare_detection(detected, grndtrth, max_distance=2)
    {'tp': array([ 5, 12, 20, 34, 41, 57, 63]),
     'fp': array([], dtype=int64),
     'fn': array([26, 68]),
     'precision': 1.0,
     'recall': 0.7777777777777778,
     'f1': 0.875}

    Finally, if detected is empty, all performance metrics will be set to zero, and a copy of
    the groundtruth array will be returned as false negatives.

    >>> compare_detection([], grndtrth)
    {'tp': array([], dtype=int64),
     'fp': array([], dtype=int64),
     'fn': array([ 5, 12, 18, 26, 34, 41, 55, 63, 68]),
     'precision': 0,
     'recall': 0,
     'f1': 0}
    """
    # Safety check. Note that all([]) == True, so empty inputs pass.
    assert all(
        float(i).is_integer() for i in indices_detection
    ), "detection indices must be integers."
    assert all(
        float(i).is_integer() for i in indices_groundtruth
    ), "groundtruth indices must be integers."
    indices_detection = np.array(indices_detection, dtype=int)  # Force copy
    indices_groundtruth = np.array(indices_groundtruth, dtype=int)
    assert indices_detection.ndim == 1, "detection indices must be a 1D list or array."
    assert indices_groundtruth.ndim == 1, "groundtruth indices must be a 1D list or array."
    assert isinstance(max_distance, int), "max_distance must be 0 or a positive integer."
    assert max_distance >= 0, "max_distance must be 0 or a positive integer."

    # Handle cases where indices_detection or indices_groundtruth is empty:
    # no overlap is possible, so all metrics are zero.
    if indices_detection.size == 0:
        return dict(
            tp=np.array([], dtype=int),
            fp=np.array([], dtype=int),
            fn=indices_groundtruth.copy(),
            precision=0,
            recall=0,
            f1=0,
        )
    if indices_groundtruth.size == 0:
        return dict(
            tp=np.array([], dtype=int),
            fp=indices_detection.copy(),
            fn=np.array([], dtype=int),
            precision=0,
            recall=0,
            f1=0,
        )

    # Create boolean masks, one sample per index position
    max_len = max(max(indices_detection), max(indices_groundtruth)) + 1
    detection_mask = np.zeros(max_len, dtype=bool)
    detection_mask[indices_detection] = 1
    true_mask = np.zeros(max_len, dtype=bool)
    true_mask[indices_groundtruth] = 1

    # Create smoothed ("fuzzy") masks: each event is dilated by
    # ``max_distance`` samples on each side via a boolean convolution.
    fuzzy_filter = np.ones(max_distance * 2 + 1, dtype=bool)
    if len(fuzzy_filter) >= max_len:
        raise ValueError(
            f"The convolution window is larger than the signal. `max_distance` should be between "
            f"0 and {int(max_len / 2 - 1)} samples."
        )
    detection_mask_fuzzy = np.convolve(detection_mask, fuzzy_filter, mode="same")
    true_mask_fuzzy = np.convolve(true_mask, fuzzy_filter, mode="same")

    # Confusion matrix and performance metrics
    results = {}
    results["tp"] = np.where(detection_mask & true_mask_fuzzy)[0]
    results["fp"] = np.where(detection_mask & ~true_mask_fuzzy)[0]
    results["fn"] = np.where(~detection_mask_fuzzy & true_mask)[0]
    n_tp, n_fp, n_fn = len(results["tp"]), len(results["fp"]), len(results["fn"])
    results["precision"] = n_tp / (n_tp + n_fp)
    results["recall"] = n_tp / (n_tp + n_fn)
    results["f1"] = n_tp / (n_tp + 0.5 * (n_fp + n_fn))
    return results