"""
Plotting functions of YASA.
"""
import mne
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from lspopt import spectrogram_lspopt
from matplotlib.colors import Normalize, ListedColormap
__all__ = ["plot_hypnogram", "plot_spectrogram", "topoplot"]
[docs]def plot_hypnogram(hyp, lw=1.5, highlight="REM", fill_color=None, ax=None):
"""
Plot a hypnogram.
.. versionadded:: 0.6.0
Parameters
----------
hyp : :py:class:`yasa.Hypnogram`
A YASA hypnogram instance.
lw : float
Linewidth.
highlight : str or None
Optional stage to highlight with alternate color.
fill_color : str or None
Optional color to fill space above hypnogram line.
ax : :py:class:`matplotlib.axes.Axes`
Axis on which to draw the plot, optional.
Returns
-------
ax : :py:class:`matplotlib.axes.Axes`
Matplotlib Axes
Examples
--------
.. plot::
>>> from yasa import simulate_hypnogram
>>> import matplotlib.pyplot as plt
>>> hyp = simulate_hypnogram(tib=300, seed=11)
>>> ax = hyp.plot_hypnogram()
>>> plt.tight_layout()
.. plot::
>>> from yasa import Hypnogram
>>> values = 4 * ["W", "N1", "N2", "N3", "REM"] + ["ART", "N2", "REM", "W", "UNS"]
>>> hyp = Hypnogram(values, freq="24min").upsample("30s")
>>> ax = hyp.plot_hypnogram(lw=2, fill_color="thistle")
>>> plt.tight_layout()
.. plot::
>>> from yasa import simulate_hypnogram
>>> import matplotlib.pyplot as plt
>>> fig, axes = plt.subplots(nrows=2, figsize=(6, 4), constrained_layout=True)
>>> hyp_a = simulate_hypnogram(n_stages=3, seed=99)
>>> hyp_b = simulate_hypnogram(n_stages=3, seed=99, start="2022-01-31 23:30:00")
>>> hyp_a.plot_hypnogram(lw=1, fill_color="whitesmoke", highlight=None, ax=axes[0])
>>> hyp_b.plot_hypnogram(lw=1, fill_color="whitesmoke", highlight=None, ax=axes[1])
"""
from yasa.hypno import Hypnogram # Avoiding circular import
assert isinstance(hyp, Hypnogram), "`hypno` must be YASA Hypnogram."
# Work with a copy of the Hypnogram to not alter the original
hyp = hyp.copy()
# Increase font size while preserving original
old_fontsize = plt.rcParams["font.size"]
plt.rcParams.update({"font.size": 18})
## Remap stages to be in desired y-axis order ##
# Start with default of all allowed labels
stage_order = hyp.labels.copy()
stages_present = hyp.hypno.unique()
# Remove Art/Uns from stage order, and place back individually at front to be higher on plot
art_str = stage_order.pop(stage_order.index("ART"))
uns_str = stage_order.pop(stage_order.index("UNS"))
if "ART" in stages_present:
stage_order.insert(0, art_str)
if "UNS" in stages_present:
stage_order.insert(0, uns_str)
# Put REM after WAKE if all 5 standard stages are allowed
if hyp.n_stages == 5:
stage_order.insert(stage_order.index("WAKE") + 1, stage_order.pop(stage_order.index("REM")))
# Reset the Hypnogram mapping so any future returns have this order
hyp.mapping = {stage: i for i, stage in enumerate(stage_order)}
## Extract values to plot ##
hypno = hyp.as_int()
# Reduce to breakpoints (where stages change) to avoid drawing individual lines for every epoch
hypno = hypno[hypno.shift().ne(hypno)]
# Extract x-values (bins) and y-values to plot
yvalues = hypno.to_numpy()
if hyp.start is not None:
final_bin_edge = pd.Timestamp(hyp.start) + pd.Timedelta(hyp.duration, unit="min")
bins = np.append(hypno.index.to_list(), final_bin_edge)
bins = [mdates.date2num(b) for b in bins]
xlabel = "Time"
else:
final_bin_edge = hyp.duration * 60
bins = np.append(hyp.timedelta[hypno.index].total_seconds(), final_bin_edge)
bins /= 60 if hyp.duration <= 90 else 3600
xlabel = "Time [mins]" if hyp.duration <= 90 else "Time [hrs]"
# Make mask to draw the highlighted stage
yvals_highlight = np.ma.masked_not_equal(yvalues, hyp.mapping.get(highlight))
# Open the figure
if ax is None:
ax = plt.gca()
# Draw background filling
if fill_color is not None:
bline = hyp.mapping["WAKE"] # len(stage_order) - 1 to fill from bottom
ax.stairs(yvalues.clip(bline), bins, baseline=bline, color=fill_color, fill=True, lw=0)
# Draw main hypnogram line, highlighted stage line, and Artefact/Unscored line
ax.stairs(yvalues, bins, baseline=None, color="black", lw=lw)
if not yvals_highlight.mask.all():
ax.hlines(yvals_highlight, xmin=bins[:-1], xmax=bins[1:], color="red", lw=lw)
# Aesthetics
ax.use_sticky_edges = False
ax.margins(x=0, y=1 / len(stage_order) / 2) # 1/n_epochs/2 gives half-unit margins
ax.set_yticks(range(len(stage_order)))
ax.set_yticklabels(stage_order)
ax.set_ylabel("Stage")
ax.set_xlabel(xlabel)
ax.invert_yaxis()
ax.spines[["right", "top"]].set_visible(False)
if hyp.start is not None:
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
ax.xaxis.set_major_locator(mdates.AutoDateLocator())
# Revert font-size
plt.rcParams.update({"font.size": old_fontsize})
return ax
[docs]def plot_spectrogram(
data,
sf,
hypno=None,
win_sec=30,
fmin=0.5,
fmax=25,
trimperc=2.5,
cmap="RdBu_r",
vmin=None,
vmax=None,
**kwargs,
):
"""
Plot a full-night multi-taper spectrogram, optionally with the hypnogram on top.
For more details, please refer to the `Jupyter notebook
<https://github.com/raphaelvallat/yasa/blob/master/notebooks/10_spectrogram.ipynb>`_
.. versionadded:: 0.1.8
Parameters
----------
data : :py:class:`numpy.ndarray`
Single-channel EEG data. Must be a 1D NumPy array.
sf : float
The sampling frequency of data AND the hypnogram.
hypno : array_like
Sleep stage (hypnogram), optional.
The hypnogram must have the exact same number of samples as ``data``.
To upsample your hypnogram, please refer to :py:func:`yasa.hypno_upsample_to_data`.
.. note::
The default hypnogram format in YASA is a 1D integer
vector where:
- -2 = Unscored
- -1 = Artefact / Movement
- 0 = Wake
- 1 = N1 sleep
- 2 = N2 sleep
- 3 = N3 sleep
- 4 = REM sleep
win_sec : int or float
The length of the sliding window, in seconds, used for multitaper PSD
calculation. Default is 30 seconds. Note that ``data`` must be at least
twice longer than ``win_sec`` (e.g. 60 seconds).
fmin, fmax : int or float
The lower and upper frequency of the spectrogram. Default 0.5 to 25 Hz.
trimperc : int or float
The amount of data to trim on both ends of the distribution when
normalizing the colormap. This parameter directly impacts the
contrast of the spectrogram plot (higher values = higher contrast).
Default is 2.5, meaning that the min and max of the colormap
are defined as the 2.5 and 97.5 percentiles of the spectrogram.
cmap : str
Colormap. Default to 'RdBu_r'.
vmin : int or float
The lower range of color scale. Overwrites ``trimperc``
vmax : int or float
The upper range of color scale. Overwrites ``trimperc``
**kwargs : dict
Other arguments that are passed to :py:meth:`yasa.Hypnogram.plot_hypnogram`.
Returns
-------
fig : :py:class:`matplotlib.figure.Figure`
Matplotlib Figure
Examples
--------
1. Full-night multitaper spectrogram on Cz, no hypnogram
.. plot::
>>> import yasa
>>> import numpy as np
>>> # In the next 5 lines, we're loading the data from GitHub.
>>> import requests
>>> from io import BytesIO
>>> r = requests.get('https://github.com/raphaelvallat/yasa/raw/master/notebooks/data_full_6hrs_100Hz_Cz%2BFz%2BPz.npz', stream=True)
>>> npz = np.load(BytesIO(r.raw.read()))
>>> data = npz.get('data')[0, :]
>>> sf = 100
>>> fig = yasa.plot_spectrogram(data, sf)
2. Full-night multitaper spectrogram on Cz with the hypnogram on top
.. plot::
>>> import yasa
>>> import numpy as np
>>> # In the next lines, we're loading the data from GitHub.
>>> import requests
>>> from io import BytesIO
>>> r = requests.get('https://github.com/raphaelvallat/yasa/raw/master/notebooks/data_full_6hrs_100Hz_Cz%2BFz%2BPz.npz', stream=True)
>>> npz = np.load(BytesIO(r.raw.read()))
>>> data = npz.get('data')[0, :]
>>> sf = 100
>>> # Load the 30-sec hypnogram and upsample to data
>>> hypno = np.loadtxt('https://raw.githubusercontent.com/raphaelvallat/yasa/master/notebooks/data_full_6hrs_100Hz_hypno_30s.txt')
>>> hypno = yasa.hypno_upsample_to_data(hypno, 1/30, data, sf)
>>> fig = yasa.plot_spectrogram(data, sf, hypno, cmap='Spectral_r')
"""
from yasa.hypno import Hypnogram, hypno_int_to_str # Avoiding circular imports
# Increase font size while preserving original
old_fontsize = plt.rcParams["font.size"]
plt.rcParams.update({"font.size": 18})
# Safety checks
assert isinstance(data, np.ndarray), "`data` must be a 1D NumPy array."
assert isinstance(sf, (int, float)), "`sf` must be int or float."
assert data.ndim == 1, "`data` must be a 1D (single-channel) NumPy array."
assert isinstance(win_sec, (int, float)), "`win_sec` must be int or float."
assert isinstance(fmin, (int, float)), "`fmin` must be int or float."
assert isinstance(fmax, (int, float)), "`fmax` must be int or float."
assert fmin < fmax, "`fmin` must be strictly inferior to `fmax`."
assert fmax < sf / 2, "`fmax` must be less than Nyquist (sf / 2)."
assert isinstance(vmin, (int, float, type(None))), "`vmin` must be int, float, or None."
assert isinstance(vmax, (int, float, type(None))), "`vmax` must be int, float, or None."
if vmin is not None:
assert isinstance(vmax, (int, float)), "`vmax` must be int or float if `vmin` is provided."
if vmax is not None:
assert isinstance(vmin, (int, float)), "`vmin` must be int or float if `vmax` is provided."
if hypno is not None:
assert hypno.size == data.size, "`hypno` must have the same number of samples as `data`."
# Calculate multi-taper spectrogram
nperseg = int(win_sec * sf)
assert data.size > 2 * nperseg, "`data` length must be at least 2 * `win_sec`."
f, t, Sxx = spectrogram_lspopt(data, sf, nperseg=nperseg, noverlap=0)
Sxx = 10 * np.log10(Sxx) # Convert uV^2 / Hz --> dB / Hz
# Select only relevant frequencies (up to 30 Hz)
good_freqs = np.logical_and(f >= fmin, f <= fmax)
Sxx = Sxx[good_freqs, :]
f = f[good_freqs]
t /= 3600 # Convert t to hours
# Normalization
if vmin is None:
vmin, vmax = np.percentile(Sxx, [0 + trimperc, 100 - trimperc])
norm = Normalize(vmin=vmin, vmax=vmax)
# Open figure
if hypno is None:
fig, ax1 = plt.subplots(nrows=1, figsize=(12, 4))
else:
fig, (ax0, ax1) = plt.subplots(
nrows=2,
figsize=(12, 6),
gridspec_kw={"height_ratios": [1, 2], "hspace": 0.1},
)
# Draw Spectrogram
im = ax1.pcolormesh(t, f, Sxx, norm=norm, cmap=cmap, antialiased=True, shading="auto")
ax1.set_xlim(0, t.max())
ax1.set_ylabel("Frequency [Hz]")
ax1.set_xlabel("Time [hrs]")
if hypno is not None:
# Convert sampling frequency to pandas timefrequency string (e.g., "30s")
freq_str = pd.tseries.frequencies.to_offset(pd.Timedelta(1 / sf, "S")).freqstr
# Create Hypnogram instance for plotting
hyp = Hypnogram(hypno_int_to_str(hypno), freq=freq_str)
hypnoplot_kwargs = dict(lw=1.5, fill_color=None)
hypnoplot_kwargs.update(kwargs)
# Draw hypnogram
ax0 = hyp.plot_hypnogram(ax=ax0, **hypnoplot_kwargs)
ax0.xaxis.set_visible(False)
else:
# Add colorbar
cbar = fig.colorbar(im, ax=ax1, shrink=0.95, fraction=0.1, aspect=25)
cbar.ax.set_ylabel("Log Power (dB / Hz)", rotation=270, labelpad=20)
# Revert font-size
plt.rcParams.update({"font.size": old_fontsize})
return fig
[docs]def topoplot(
data,
montage="standard_1020",
vmin=None,
vmax=None,
mask=None,
title=None,
cmap=None,
n_colors=100,
cbar_title=None,
cbar_ticks=None,
figsize=(4, 4),
dpi=80,
fontsize=14,
**kwargs,
):
"""
Topoplot.
This is a wrapper around :py:func:`mne.viz.plot_topomap`.
For more details, please refer to this `example notebook
<https://github.com/raphaelvallat/yasa/blob/master/notebooks/15_topoplot.ipynb>`_.
.. versionadded:: 0.4.1
Parameters
----------
data : :py:class:`pandas.Series`
A pandas Series with the values to plot. The index MUST be the channel
names (e.g. ['C4', 'F4'] or ['C4-M1', 'C3-M2']).
montage : str
The name of the montage to use. Valid montages can be found at
:py:func:`mne.channels.make_standard_montage`.
vmin, vmax : float
The minimum and maximum values of the colormap. If None, these will be
defined based on the min / max values of ``data``.
mask : :py:class:`pandas.Series`
A pandas Series indicating the significant electrodes. The index MUST
be the channel names (e.g. ['C4', 'F4'] or ['C4-M1', 'C3-M2']).
title : str
The plot title.
cmap : str
A matplotlib color palette. A list of color palette can be found at:
https://seaborn.pydata.org/tutorial/color_palettes.html
n_colors : int
The number of colors to discretize the color palette.
cbar_title : str
The title of the colorbar.
cbar_ticks : list
The ticks of the colorbar.
figsize : tuple
Width, height in inches.
dpi : int
The resolution of the plot.
fontsize : int
Global font size of all the elements of the plot.
**kwargs : dict
Other arguments that are passed to :py:func:`mne.viz.plot_topomap`.
Returns
-------
fig : :py:class:`matplotlib.figure.Figure`
Matplotlib Figure
Examples
--------
1. Plot all-positive values
.. plot::
>>> import yasa
>>> import pandas as pd
>>> data = pd.Series([4, 8, 7, 1, 2, 3, 5],
... index=['F4', 'F3', 'C4', 'C3', 'P3', 'P4', 'Oz'],
... name='Values')
>>> fig = yasa.topoplot(data, title='My first topoplot')
2. Plot correlation coefficients (values ranging from -1 to 1)
.. plot::
>>> import yasa
>>> import pandas as pd
>>> data = pd.Series([-0.5, -0.7, -0.3, 0.1, 0.15, 0.3, 0.55],
... index=['F3', 'Fz', 'F4', 'C3', 'Cz', 'C4', 'Pz'])
>>> fig = yasa.topoplot(data, vmin=-1, vmax=1, n_colors=8,
... cbar_title="Pearson correlation")
"""
# Increase font size while preserving original
old_fontsize = plt.rcParams["font.size"]
plt.rcParams.update({"font.size": fontsize})
plt.rcParams.update({"savefig.bbox": "tight"})
plt.rcParams.update({"savefig.transparent": "True"})
# Make sure we don't do any in-place modification
assert isinstance(data, pd.Series), "`data` must be a Pandas Series"
data = data.copy()
# Add mask, if present
if mask is not None:
assert isinstance(mask, pd.Series), "`mask` must be a Pandas Series"
assert mask.dtype.kind in "bi", "`mask` must be True/False or 0/1."
else:
mask = pd.Series(1, index=data.index, name="mask")
# Convert to a dataframe (col1 = values, col2 = mask)
data = data.to_frame().join(mask, how="left")
# Preprocess channel names: C4-M1 --> C4
data.index = data.index.str.split("-").str.get(0)
# Define electrodes coordinates
Info = mne.create_info(data.index.tolist(), sfreq=100, ch_types="eeg")
Info.set_montage(montage, match_case=False, on_missing="ignore")
chan = Info.ch_names
# Define vmin and vmax
if vmin is None:
vmin = data.iloc[:, 0].min()
if vmax is None:
vmax = data.iloc[:, 0].max()
# Choose and discretize colormap
if cmap is None:
if vmin < 0 and vmax <= 0:
cmap = "mako"
elif vmin < 0 and vmax > 0:
cmap = "Spectral_r"
elif vmin >= 0 and vmax > 0:
cmap = "rocket_r"
cmap = ListedColormap(sns.color_palette(cmap, n_colors).as_hex())
if "sensors" not in kwargs:
kwargs["sensors"] = False
if "res" not in kwargs:
kwargs["res"] = 256
if "names" not in kwargs:
kwargs["names"] = chan
if "mask_params" not in kwargs:
kwargs["mask_params"] = dict(marker=None)
# Hidden feature: if names='values', show the actual values.
if kwargs["names"] == "values":
kwargs["names"] = data.iloc[:, 0][chan].round(2).to_numpy()
# Start the plot
with sns.axes_style("white"):
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
if "show_names" in kwargs:
kwargs.pop("show_names")
im, _ = mne.viz.plot_topomap(
data=data.iloc[:, 0][chan],
pos=Info,
vlim=(vmin, vmax),
mask=data.iloc[:, 1][chan],
cmap=cmap,
show=False,
axes=ax,
**kwargs,
)
if title is not None:
ax.set_title(title)
# Add colorbar
if cbar_title is None:
cbar_title = data.iloc[:, 0].name
cax = fig.add_axes([0.95, 0.3, 0.02, 0.5])
cbar = fig.colorbar(im, cax=cax, ticks=cbar_ticks, fraction=0.5)
cbar.set_label(cbar_title)
# Revert font-size
plt.rcParams.update({"font.size": old_fontsize})
return fig