yasa mice branch #72

Open
matiasandina opened this issue Jun 7, 2022 · 11 comments
Labels
enhancement 🚧 New feature or request

Comments

@matiasandina
Contributor

This issue briefly describes what I changed to adapt staging.py to work with my mouse recordings.

The most significant change is the use of epoch_sec in get_features(), fit(), and sliding_window().

I don't remember why I kept this min() call. My epoch_sec was 2.5 seconds, so I didn't test what happens when epoch_sec is different.

win_sec = min(5, epoch_sec)  # = 2 / freq_broad[0]
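For context on the comment: 2 / freq_broad[0] = 2 / 0.4 = 5 s, i.e. a Welch window that holds two cycles of the slowest frequency of interest. The Welch window also sets the frequency resolution (sf / nperseg = 1 / win_sec). A minimal sketch, assuming data already resampled to 100 Hz as in the class, shows that with 2.5-s epochs the resolution drops from 0.2 Hz to 0.4 Hz, so the 0.4-1 Hz sdelta band is covered by only a couple of frequency bins. welch_resolution is a hypothetical helper written for illustration.

```python
import numpy as np
from scipy.signal import welch


def welch_resolution(epoch_sec, sf=100.0, max_win_sec=5.0):
    """Return (win_sec, frequency resolution in Hz) for the fork's Welch window.

    Mirrors ``win_sec = min(5, epoch_sec)``: the Welch segment can never be
    longer than the epoch itself.
    """
    win_sec = min(max_win_sec, epoch_sec)
    nperseg = int(win_sec * sf)
    # Dummy white-noise epoch, only used to read back the frequency grid.
    rng = np.random.default_rng(0)
    epoch = rng.standard_normal(int(epoch_sec * sf))
    freqs, _ = welch(epoch, sf, window="hamming", nperseg=nperseg, average="median")
    return win_sec, freqs[1] - freqs[0]


for epoch_sec in (30, 2.5):
    win_sec, df = welch_resolution(epoch_sec)
    print(f"epoch={epoch_sec} s -> window={win_sec} s, resolution={df:.2f} Hz")
```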

I removed the temporal features because mice don't sleep in one consolidated bout. I think this might also help with classifying human nap data. I just commented them out, but it wouldn't be difficult to add a conditional statement there or find a better solution.

        # Add temporal features
        # for mice, relying on "time since start of the night"
        # is a bad idea
        # if we remove this, we can't use the default classifier
        #features['time_hour'] = times / 3600
        #features['time_norm'] = times / times[-1]
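A minimal sketch of the conditional mentioned above, assuming the per-epoch times (in seconds) used inside fit(). add_temporal_features and its add_temporal flag are hypothetical names, not part of YASA; the two feature definitions are copied from the commented-out lines.

```python
import numpy as np
import pandas as pd


def add_temporal_features(features, times, add_temporal=True):
    """Optionally append YASA's two time-of-recording features.

    ``add_temporal`` is a hypothetical flag, not part of YASA. With
    ``add_temporal=False`` this reproduces the mice fork (lines commented out),
    but then YASA's default pre-trained classifiers can no longer be used,
    because they expect these columns.
    """
    features = features.copy()
    if add_temporal:
        times = np.asarray(times, dtype=float)
        features["time_hour"] = times / 3600       # hours since recording start
        features["time_norm"] = times / times[-1]  # scaled so the last epoch is 1
    return features


# Usage: three 30-s epochs starting at 0, 30 and 60 s.
feats = pd.DataFrame({"eeg_std": [1.0, 2.0, 3.0]})
print(add_temporal_features(feats, [0, 30, 60]).columns.tolist())
print(add_temporal_features(feats, [0, 30, 60], add_temporal=False).columns.tolist())
```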

I also changed the units (the conversion to microvolts is commented out); I think this has been superseded by #59.

        # Get data (the conversion to microvolts is commented out)
        data = raw_pick.get_data()  # * 1e6

Another minor issue is the naming of the features, which hardcodes "min" into the column names (e.g. the _c7min_norm and _p2min_norm suffixes). I would consider using "epoch" instead of "min".
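A sketch of how the two rolling windows in fit() could follow epoch_sec instead of assuming 30-s epochs, with suffixes named in epochs (e.g. _c15epoch_norm, as suggested in the code comments of the full file). smooth_features, center_sec and past_sec are hypothetical names. The robust scaling reimplements robust_scale(quantile_range=(5, 95)) with NumPy (median centring, division by the 5th-95th percentile range) to avoid the scikit-learn dependency; the triangular window still needs SciPy, as in the original. With the defaults, 30-s epochs give the original 15- and 4-epoch windows, and 2.5-s epochs give 180 and 48.

```python
import numpy as np
import pandas as pd


def smooth_features(features, epoch_sec=30, center_sec=450, past_sec=120):
    """Rolling-average smoothing with window lengths given in seconds."""
    n_center = max(1, int(round(center_sec / epoch_sec)))  # 7.5 min by default
    n_past = max(1, int(round(past_sec / epoch_sec)))      # 2 min by default

    def _robust(df):
        # Subtract the median, divide by the 5-95 percentile range per column.
        q5, q50, q95 = np.nanpercentile(df.to_numpy(), [5, 50, 95], axis=0)
        scale = np.where(q95 - q5 == 0, 1.0, q95 - q5)
        return (df - q50) / scale

    rollc = features.rolling(window=n_center, center=True, min_periods=1,
                             win_type="triang").mean()
    rollp = features.rolling(window=n_past, min_periods=1).mean()
    # Suffixes report the window length in epochs, not minutes.
    rollc = _robust(rollc).add_suffix(f"_c{n_center}epoch_norm")
    rollp = _robust(rollp).add_suffix(f"_p{n_past}epoch_norm")
    return features.join(rollc).join(rollp)


# Usage with 2.5-s epochs (mouse data in this thread).
feats = pd.DataFrame({"eeg_std": np.random.default_rng(0).random(400)})
print(smooth_features(feats, epoch_sec=2.5).columns.tolist())
```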

I also plan to change the following minimum-duration check in the future, because I would like to be able to run YASA in real time.

        # Extract duration of recording in minutes
        duration_minutes = data.shape[1] / sf / 60
        assert duration_minutes >= 5, 'At least 5 minutes of data is required.'

I think these lines create problems for people working with mouse data, because they usually don't use all of these ratios. I used them for my classifier and I think they carry useful information, but it would be nice to check that the required bands are present before calculating the ratios.

            # Add power ratios for EEG
            # TODO: when some bands are not included,
            # this results in a KeyError
            if c == 'eeg':
                delta = feat['sdelta'] + feat['fdelta']
                feat['dt'] = delta / feat['theta']
                feat['ds'] = delta / feat['sigma']
                feat['db'] = delta / feat['beta']
                feat['at'] = feat['alpha'] / feat['theta']
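One way to address the TODO, sketched under the assumption that feat is the per-channel dict of per-epoch arrays built in fit(): compute each ratio only when all of its bands exist, and skip it otherwise. add_power_ratios and RATIOS are hypothetical names. Skipping (rather than, say, redefining delta from whatever bands are available) keeps each ratio's meaning identical across models; the catch is that a classifier trained with a given set of ratios still needs exactly those columns at prediction time, which _validate_predict checks.

```python
import numpy as np

# ratio name -> (bands summed in the numerator, band in the denominator),
# following the ratios computed in fit().
RATIOS = {
    "dt": (("sdelta", "fdelta"), "theta"),
    "ds": (("sdelta", "fdelta"), "sigma"),
    "db": (("sdelta", "fdelta"), "beta"),
    "at": (("alpha",), "theta"),
}


def add_power_ratios(feat):
    """Return a copy of ``feat`` with every ratio whose bands are all present."""
    feat = dict(feat)
    for name, (num_bands, den_band) in RATIOS.items():
        if all(b in feat for b in (*num_bands, den_band)):
            numerator = sum(np.asarray(feat[b]) for b in num_bands)
            feat[name] = numerator / np.asarray(feat[den_band])
    return feat


# Usage: a band set without sdelta/sigma/beta no longer raises a KeyError.
partial = {"fdelta": np.array([2.0]), "theta": np.array([1.0]), "alpha": np.array([0.5])}
print(sorted(add_power_ratios(partial)))
```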

Below, you can find the full file.

"""Automatic sleep staging of polysomnography data."""
import os
import mne
import glob
import joblib
import logging
import numpy as np
import pandas as pd
import antropy as ant
import scipy.signal as sp_sig
import scipy.stats as sp_stats
import matplotlib.pyplot as plt
from mne.filter import filter_data
from sklearn.preprocessing import robust_scale

from yasa.others import sliding_window
from yasa.spectral import bandpower_from_psd_ndarray

logger = logging.getLogger('yasa')


class SleepStaging:
    """
    Automatic sleep staging of polysomnography data.

    To run the automatic sleep staging, you must install the
    `LightGBM <https://lightgbm.readthedocs.io/>`_ and
    `antropy <https://github.com/raphaelvallat/antropy>`_ packages.

    .. versionadded:: 0.4.0

    Parameters
    ----------
    raw : :py:class:`mne.io.BaseRaw`
        An MNE Raw instance.
    eeg_name : str
        The name of the EEG channel in ``raw``. Preferentially a central
        electrode referenced either to the mastoids (C4-M1, C3-M2) or to the
        Fpz electrode (C4-Fpz). Data are assumed to be in Volts (MNE default)
        and will be converted to uV.
    eog_name : str or None
        The name of the EOG channel in ``raw``. Preferentially,
        the left LOC channel referenced either to the mastoid (e.g. E1-M2)
        or Fpz. Can also be None.
    emg_name : str or None
        The name of the EMG channel in ``raw``. Preferentially a chin
        electrode. Can also be None.
    metadata : dict or None
        A dictionary of metadata (optional). Currently supported keys are:

        * ``'age'``: age of the participant, in years.
        * ``'male'``: sex of the participant (1 or True = male, 0 or
          False = female)

    Notes
    -----

    If you use the SleepStaging module in a publication, please cite the following publication:

    * Vallat, R., & Walker, M. P. (2021). An open-source, high-performance tool for automated
      sleep staging. Elife, 10. doi: https://doi.org/10.7554/eLife.70092

    We provide below some key points on the algorithm and its validation. For more details,
    we refer the reader to the peer-reviewed publication. If you have any questions,
    make sure to first check the
    `FAQ section <https://raphaelvallat.com/yasa/build/html/faq.html>`_ of the documentation.
    If you did not find the answer to your question, please feel free to open an issue on GitHub.

    **1. Features extraction**

    For each 30-second epoch and each channel, the following features are calculated:

    * Standard deviation
    * Interquartile range
    * Skewness and kurtosis
    * Number of zero crossings
    * Hjorth mobility and complexity
    * Absolute total power in the 0.4-30 Hz band.
    * Relative power in the main frequency bands (for EEG and EOG only)
    * Power ratios (e.g. delta / beta)
    * Permutation entropy
    * Higuchi and Petrosian fractal dimension

    In addition, the algorithm also calculates a smoothed and normalized version of these features.
    Specifically, a 7.5 min centered triangular-weighted rolling average and a 2 min past rolling
    average are applied. The resulting smoothed features are then normalized using a robust
    z-score.

    .. important:: The PSG data should be in micro-Volts. Do NOT transform (e.g. z-score) or filter
        the signal before running the sleep staging algorithm.

    The data are automatically downsampled to 100 Hz for faster computation.

    **2. Sleep stages prediction**

    YASA comes with a default set of pre-trained classifiers, which were trained and validated
    on ~3000 nights from the `National Sleep Research Resource <https://sleepdata.org/>`_.
    These nights involved participants from a wide age range, of different ethnicities, gender,
    and health status. The default classifiers should therefore work reasonably well on most data.

    The code that was used to train the classifiers can be found on GitHub at:
    https://github.com/raphaelvallat/yasa_classifier

    In addition to the predicted sleep stages, YASA can also return the predicted probabilities
    of each sleep stage at each epoch. This can be used to derive a confidence score at each epoch.

    .. important:: The predictions should ALWAYS be double-checked by a trained
        visual scorer, especially for epochs with low confidence. A full
        inspection should be performed in the following cases:

        * Nap data, because the classifiers were exclusively trained on full-night recordings.
        * Participants with sleep disorders.
        * Sub-optimal PSG system and/or referencing.

    .. warning:: N1 sleep is the sleep stage with the lowest detection accuracy. This is expected
        because N1 is also the stage with the lowest human inter-rater agreement. Be very
        careful for potential misclassification of N1 sleep (e.g. scored as Wake or N2) when
        inspecting the predicted sleep stages.

    References
    ----------
    If you use YASA's default classifiers, these are the main references for
    the `National Sleep Research Resource <https://sleepdata.org/>`_:

    * Dean, Dennis A., et al. "Scaling up scientific discovery in sleep
      medicine: the National Sleep Research Resource." Sleep 39.5 (2016):
      1151-1164.

    * Zhang, Guo-Qiang, et al. "The National Sleep Research Resource: towards
      a sleep data commons." Journal of the American Medical Informatics
      Association 25.10 (2018): 1351-1358.

    Examples
    --------
    For a concrete example, please refer to the example Jupyter notebook:
    https://github.com/raphaelvallat/yasa/blob/master/notebooks/14_automatic_sleep_staging.ipynb

    >>> import mne
    >>> import yasa
    >>> # Load an EDF file using MNE
    >>> raw = mne.io.read_raw_edf("myfile.edf", preload=True)
    >>> # Initialize the sleep staging instance
    >>> sls = yasa.SleepStaging(raw, eeg_name="C4-M1", eog_name="LOC-M2",
    ...                         emg_name="EMG1-EMG2",
    ...                         metadata=dict(age=29, male=True))
    >>> # Get the predicted sleep stages
    >>> hypno = sls.predict()
    >>> # Get the predicted probabilities
    >>> proba = sls.predict_proba()
    >>> # Get the confidence
    >>> confidence = proba.max(axis=1)
    >>> # Plot the predicted probabilities
    >>> sls.plot_predict_proba()

    The sleep scores can then be manually edited in an external graphical user interface
    (e.g. EDFBrowser), as described in the
    `FAQ <https://raphaelvallat.com/yasa/build/html/faq.html>`_.
    """

    def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None):
        # Type check
        assert isinstance(eeg_name, str)
        assert isinstance(eog_name, (str, type(None)))
        assert isinstance(emg_name, (str, type(None)))
        assert isinstance(metadata, (dict, type(None)))

        # Validate metadata
        if isinstance(metadata, dict):
            if 'age' in metadata.keys():
                assert 0 < metadata['age'] < 120, 'age must be between 0 and 120.'
            if 'male' in metadata.keys():
                metadata['male'] = int(metadata['male'])
                assert metadata['male'] in [0, 1], 'male must be 0 or 1.'

        # Validate Raw instance and load data
        assert isinstance(raw, mne.io.BaseRaw), 'raw must be a MNE Raw object.'
        sf = raw.info['sfreq']
        ch_names = np.array([eeg_name, eog_name, emg_name])
        ch_types = np.array(['eeg', 'eog', 'emg'])
        keep_chan = []
        for c in ch_names:
            if c is not None:
                assert c in raw.ch_names, '%s does not exist' % c
                keep_chan.append(True)
            else:
                keep_chan.append(False)
        # Subset
        ch_names = ch_names[keep_chan].tolist()
        ch_types = ch_types[keep_chan].tolist()
        # Keep only selected channels (creating a copy of Raw)
        raw_pick = raw.copy().pick_channels(ch_names, ordered=True)

        # Downsample if sf != 100
        assert sf > 80, 'Sampling frequency must be at least 80 Hz.'
        if sf != 100:
            raw_pick.resample(100, npad="auto")
            sf = raw_pick.info['sfreq']

        # Get data (the conversion to microvolts is commented out)
        data = raw_pick.get_data()  # * 1e6

        # Extract duration of recording in minutes
        duration_minutes = data.shape[1] / sf / 60
        assert duration_minutes >= 5, 'At least 5 minutes of data is required.'

        # Add to self
        self.sf = sf
        self.ch_names = ch_names
        self.ch_types = ch_types
        self.data = data
        self.metadata = metadata

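    def fit(self, epoch_sec=30, bands=None):
        """Extract features from data.

        Parameters
        ----------
        epoch_sec : float
            Time window in seconds to be used for feature extraction. Defaults to 30 seconds.
        bands : list of tuples or None
            Frequency bands as (lower, upper, name) tuples. If None, the default
            bands (sdelta, fdelta, theta, alpha, sigma, beta) are used.

        Returns
        -------
        self : returns an instance of self.
        """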
        #######################################################################
        # MAIN PARAMETERS
        #######################################################################
    
        # Bandpass filter
        freq_broad = (0.4, 30)
        # FFT & bandpower parameters
        win_sec = min(5, epoch_sec)  # = 2 / freq_broad[0]
        sf = self.sf
        win = int(win_sec * sf)
        kwargs_welch = dict(window='hamming', nperseg=win, average='median')
        if bands is None:
            bands = [
                (0.4, 1, 'sdelta'), (1, 4, 'fdelta'), (4, 8, 'theta'),
                (8, 12, 'alpha'), (12, 16, 'sigma'), (16, 30, 'beta')
            ]
    
        #######################################################################
        # CALCULATE FEATURES
        #######################################################################
    
        features = []
    
        for i, c in enumerate(self.ch_types):
            # Preprocessing
            # - Filter the data
            dt_filt = filter_data(
                self.data[i, :], sf, l_freq=freq_broad[0], h_freq=freq_broad[1], verbose=False)
            # - Extract epochs. Data is now of shape (n_epochs, n_samples).
            times, epochs = sliding_window(dt_filt, sf=sf, window=epoch_sec)
    
            # Calculate standard descriptive statistics
            hmob, hcomp = ant.hjorth_params(epochs, axis=1)
    
            feat = {
                'std': np.std(epochs, ddof=1, axis=1),
                'iqr': sp_stats.iqr(epochs, rng=(25, 75), axis=1),
                'skew': sp_stats.skew(epochs, axis=1),
                'kurt': sp_stats.kurtosis(epochs, axis=1),
                'nzc': ant.num_zerocross(epochs, axis=1),
                'hmob': hmob,
                'hcomp': hcomp
            }
    
            # Calculate spectral power features (for EEG + EOG)
            freqs, psd = sp_sig.welch(epochs, sf, **kwargs_welch)
            if c != 'emg':
                bp = bandpower_from_psd_ndarray(psd, freqs, bands=bands)
                for j, (_, _, b) in enumerate(bands):
                    feat[b] = bp[j]
    
            # Add power ratios for EEG
            # TODO: when some bands are not included,
            # this results in a KeyError
            if c == 'eeg':
                delta = feat['sdelta'] + feat['fdelta']
                feat['dt'] = delta / feat['theta']
                feat['ds'] = delta / feat['sigma']
                feat['db'] = delta / feat['beta']
                feat['at'] = feat['alpha'] / feat['theta']
    
            # Add total power
            idx_broad = np.logical_and(freqs >= freq_broad[0], freqs <= freq_broad[1])
            dx = freqs[1] - freqs[0]
            feat['abspow'] = np.trapz(psd[:, idx_broad], dx=dx)
    
            # Calculate entropy and fractal dimension features
            feat['perm'] = np.apply_along_axis(
                ant.perm_entropy, axis=1, arr=epochs, normalize=True)
            feat['higuchi'] = np.apply_along_axis(
                ant.higuchi_fd, axis=1, arr=epochs)
            feat['petrosian'] = ant.petrosian_fd(epochs, axis=1)
    
            # Convert to dataframe
            feat = pd.DataFrame(feat).add_prefix(c + '_')
            features.append(feat)
    
        #######################################################################
        # SMOOTHING & NORMALIZATION
        #######################################################################
    
        # Save features to dataframe
        features = pd.concat(features, axis=1)
        features.index.name = 'epoch'
        
        # TODO: change here, rolling windows are hardcoded
        # and assume epoch = 30 sec
        # this will change when epochs change
        # I would consider changing this to '_c15epoch_norm'
        # Apply centered rolling average (15 epochs = 7 min 30)
        # Triang: [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.,
        #          0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125]
        rollc = features.rolling(
            window=15, center=True, min_periods=1, win_type='triang').mean()
        rollc[rollc.columns] = robust_scale(rollc, quantile_range=(5, 95))
        rollc = rollc.add_suffix('_c7min_norm')
    
        # Now look at the past 2 minutes
        rollp = features.rolling(window=4, min_periods=1).mean()
        rollp[rollp.columns] = robust_scale(rollp, quantile_range=(5, 95))
        rollp = rollp.add_suffix('_p2min_norm')
    
        # Add to current set of features
        features = features.join(rollc).join(rollp)
    
        #######################################################################
        # TEMPORAL + METADATA FEATURES AND EXPORT
        #######################################################################
    
        # Add temporal features
        # for mice, relying on "time since start of the night"
        # is a bad idea
        # if we remove this, we can't use the default classifier
        #features['time_hour'] = times / 3600
        #features['time_norm'] = times / times[-1]
    
        # Add metadata if present
        if self.metadata is not None:
            for c in self.metadata.keys():
                features[c] = self.metadata[c]
    
        # Downcast float64 to float32 (to reduce size of training datasets)
        cols_float = features.select_dtypes(np.float64).columns.tolist()
        features[cols_float] = features[cols_float].astype(np.float32)
        # Make sure that age and sex are encoded as int
        if 'age' in features.columns:
            features['age'] = features['age'].astype(int)
        if 'male' in features.columns:
            features['male'] = features['male'].astype(int)
    
        # Sort the column names here (same behavior as lightGBM)
        features.sort_index(axis=1, inplace=True)
    
        # Add to self
        self._features = features
        self.feature_name_ = self._features.columns.tolist()
        return self

    def get_features(self, epoch_sec=30, bands=None):
        """Extract features from data and return a copy of the dataframe.

        Returns
        -------
        features : :py:class:`pandas.DataFrame`
            Feature dataframe.
        """
        if not hasattr(self, '_features'):
            self.fit(epoch_sec, bands)
        return self._features.copy()

    def _validate_predict(self, clf):
        """Validate classifier."""
        # Check that we're using exactly the same features
        # Note that clf.feature_name_ is only available in lightgbm>=3.0
        f_diff = np.setdiff1d(clf.feature_name_, self.feature_name_)
        if len(f_diff):
            raise ValueError("The following features are present in the "
                             "classifier but not in the current features set:", f_diff)
        f_diff = np.setdiff1d(self.feature_name_, clf.feature_name_)
        if len(f_diff):
            raise ValueError("The following features are present in the "
                             "current feature set but not in the classifier:", f_diff)

    def _load_model(self, path_to_model):
        """Load the relevant trained classifier."""
        if path_to_model == "auto":
            from pathlib import Path
            clf_dir = os.path.join(str(Path(__file__).parent), 'classifiers/')
            name = 'clf_eeg'
            name = name + '+eog' if 'eog' in self.ch_types else name
            name = name + '+emg' if 'emg' in self.ch_types else name
            name = name + '+demo' if self.metadata is not None else name
            # e.g. clf_eeg+eog+emg+demo_lgb_0.4.0.joblib
            all_matching_files = glob.glob(clf_dir + name + "*.joblib")
            # Find the latest file
            path_to_model = np.sort(all_matching_files)[-1]
        # Check that file exists
        assert os.path.isfile(path_to_model), "File does not exist."
        logger.info("Using pre-trained classifier: %s" % path_to_model)
        # Load using Joblib
        clf = joblib.load(path_to_model)
        # Validate features
        self._validate_predict(clf)
        return clf

    def predict(self, path_to_model="auto", epoch_sec=30, bands=None):
        """
        Return the predicted sleep stage for each epoch of data.

        Currently, only classifiers that were trained using a
        `LGBMClassifier <https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html>`_
        are supported.

        Parameters
        ----------
        path_to_model : str or "auto"
            Full path to a trained LGBMClassifier, exported as a joblib file. Can be "auto" to
            use YASA's default classifier.
        epoch_sec : float
            Epoch length in seconds, passed to :py:meth:`fit` if the features have not been
            extracted yet. It should match the epoch length used to train the classifier.
        bands : list of tuples or None
            Frequency bands, passed to :py:meth:`fit`.

        Returns
        -------
        pred : :py:class:`numpy.ndarray`
            The predicted sleep stages.
        """
        if not hasattr(self, '_features'):
            self.fit(epoch_sec, bands)
        # Load and validate pre-trained classifier
        clf = self._load_model(path_to_model)
        # Now we make sure that the features are aligned
        X = self._features.copy()[clf.feature_name_]
        # Predict the sleep stages and probabilities
        self._predicted = clf.predict(X)
        proba = pd.DataFrame(clf.predict_proba(X), columns=clf.classes_)
        proba.index.name = 'epoch'
        self._proba = proba
        return self._predicted.copy()

    def predict_proba(self, path_to_model="auto", epoch_sec=30, bands=None):
        """
        Return the predicted probability for each sleep stage for each epoch of data.

        Currently, only classifiers that were trained using a
        `LGBMClassifier <https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html>`_
        are supported.

        Parameters
        ----------
        path_to_model : str or "auto"
            Full path to a trained LGBMClassifier, exported as a joblib file. Can be "auto" to
            use YASA's default classifier.
        epoch_sec : float
            Epoch length in seconds, passed to :py:meth:`predict`.
        bands : list of tuples or None
            Frequency bands, passed to :py:meth:`predict`.

        Returns
        -------
        proba : :py:class:`pandas.DataFrame`
            The predicted probability for each sleep stage for each epoch of data.
        """
        if not hasattr(self, '_proba'):
            self.predict(path_to_model, epoch_sec, bands)
        return self._proba.copy()

    def plot_predict_proba(self, proba=None, majority_only=False,
                           palette=['#99d7f1', '#009DDC', 'xkcd:twilight blue',
                                    'xkcd:rich purple', 'xkcd:sunflower']):
        """
        Plot the predicted probability for each sleep stage for each 30-sec epoch of data.

        Parameters
        ----------
        proba : self or DataFrame
            A dataframe with the probability of each sleep stage for each 30-sec epoch of data.
        majority_only : boolean
            If True, probabilities of the non-majority classes will be set to 0.
        """
        if proba is None and not hasattr(self, '_proba'):
            raise ValueError("Must call .predict_proba before this function")
        if proba is None:
            proba = self._proba.copy()
        else:
            assert isinstance(proba, pd.DataFrame), 'proba must be a dataframe'
        if majority_only:
            cond = proba.apply(lambda x: x == x.max(), axis=1)
            proba = proba.where(cond, other=0)
        ax = proba.plot(kind='area', color=palette, figsize=(10, 5), alpha=.8, stacked=True, lw=0)
        # Add confidence
        # confidence = proba.max(1)
        # ax.plot(confidence, lw=1, color='k', ls='-', alpha=0.5,
        #         label='Confidence')
        ax.set_xlim(0, proba.shape[0])
        ax.set_ylim(0, 1)
        ax.set_ylabel("Probability")
        ax.set_xlabel("Time (30-sec epoch)")
        plt.legend(frameon=False, bbox_to_anchor=(1, 1))
        return ax

@raphaelvallat raphaelvallat added the enhancement 🚧 New feature or request label Jun 8, 2022
@raphaelvallat
Owner

This is great @matiasandina!! I'm sure this will be super useful to others. Can I ask: did you re-train a new classifier using these updated features? If so, can you describe the training set as well? And more importantly, would you be willing to share the training set and/or the trained classifier (lightgbm tree paths)?

@matiasandina
Contributor Author

matiasandina commented Jun 8, 2022

The dataset was obtained from OSF. It contains mouse EEG/EMG recordings (sampling rate: 512 Hz) and sleep stage labels (epoch length: 2.5 sec).
Training was performed on one portion of the open-source dataset and testing on another (the partition is not random, but the training and testing sets are separate). The classifier was trained on features extracted from 24-h recordings.
I am happy to share everything, what's the best way to go about it?
Right now, the Python code for the procedure is split across three .qmd files (https://quarto.org/). I can make a repo contribution, but the dataset needs to be downloaded independently and organized into the folder structure the code expects.
I also have the generated classifier with epoch = 2.5 seconds.
This classifier could probably be improved a bit by smoothing out fast transitions, which I didn't get to do: there are a few very fast transitions that are unlikely (e.g. a 2.5-sec bout of NREM in the middle of wake).
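A minimal sketch of that kind of post-hoc smoothing, assuming the hypnogram is a 1-D array of stage labels: bouts shorter than min_len epochs are relabeled when the bouts on both sides agree, so a lone NREM epoch inside Wake becomes Wake. remove_short_bouts and the default min_len=2 are illustrative choices, not a validated rule. The pass runs left to right on the already-updated labels, and the first and last bouts are never changed because they only have one neighbour.

```python
import numpy as np


def remove_short_bouts(hypno, min_len=2):
    """Relabel bouts shorter than ``min_len`` epochs when both neighbours agree."""
    hypno = np.asarray(hypno).copy()
    # Run-length encoding: start and end (exclusive) index of each bout.
    change = np.flatnonzero(hypno[1:] != hypno[:-1]) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [len(hypno)]))
    for k in range(1, len(starts) - 1):
        s, e = starts[k], ends[k]
        prev_stage, next_stage = hypno[s - 1], hypno[e]
        if e - s < min_len and prev_stage == next_stage:
            hypno[s:e] = prev_stage
    return hypno


# Usage: the isolated N epoch inside Wake is absorbed; longer bouts are kept.
print(remove_short_bouts(["W", "W", "N", "W", "W", "N", "N", "N", "R", "R"]))
```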

@raphaelvallat
Copy link
Owner

Thanks @matiasandina! I think that if we are to create a new branch (like "yasa_mice"), the minimum that we need is:

And then, we can create a separate repo (see https://github.com/raphaelvallat/yasa_classifier as an example) with the code to reproduce the trained classifier, i.e. model training, data partitioning, performance evaluation, etc.

A few questions:

  • What is the performance of the algorithm on the test set?
  • Is there a single EEG channel in this dataset? Do you think that the choice of the electrode will have a strong impact for mice sleep?

@matiasandina
Copy link
Contributor Author

matiasandina commented Jun 10, 2022

This is what I can do for now; please let me know if you can reproduce it.
https://github.com/matiasandina/yasa_classifier


Performance is quite high: above 90% for accuracy, and a bit lower for Cohen's kappa. I think smoothing could improve performance further, but I'm OK with it so far. You can find the performance values in this folder, which includes accuracy, Cohen's kappa, and confusion matrices for each of the 50 4h recordings that I used for testing.
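For reference, a NumPy-only sketch of the two per-recording metrics (scikit-learn's accuracy_score and cohen_kappa_score compute the same quantities). Kappa discounts the agreement expected by chance from the label frequencies, which is why it sits below accuracy when one stage dominates a recording. accuracy_and_kappa and the results dict are hypothetical.

```python
import numpy as np


def accuracy_and_kappa(y_true, y_pred):
    """Accuracy and Cohen's kappa for one recording.

    kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement
    (accuracy) and p_e the agreement expected by chance from the marginals.
    """
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    labels = np.union1d(y_true, y_pred)
    # Confusion matrix: rows = true stage, columns = predicted stage.
    cm = np.zeros((labels.size, labels.size))
    np.add.at(cm, (np.searchsorted(labels, y_true), np.searchsorted(labels, y_pred)), 1)
    n = cm.sum()
    p_o = np.trace(cm) / n
    p_e = cm.sum(axis=1) @ cm.sum(axis=0) / n**2
    kappa = (p_o - p_e) / (1 - p_e) if p_e < 1 else np.nan  # undefined if one label only
    return p_o, kappa


# Usage with a hypothetical dict {recording_id: (y_true, y_pred)}:
results = {"rec01": (["W", "W", "N", "R"], ["W", "N", "N", "R"])}
for rec, (y_true, y_pred) in results.items():
    acc, kappa = accuracy_and_kappa(y_true, y_pred)
    print(f"{rec}: accuracy={acc:.2f}, kappa={kappa:.2f}")
```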

The dataset contains one EEG and one EMG channel. I believe the choice of electrode will certainly affect the algorithm. Again, data in mice is not as standard. Even labs that have detailed information see that different parts of the brain enter stages at slightly different moments. Below is a result from LFP electrodes by Soltani 2019:

[Image: sleep_electrode_array]

I am collecting data with a "high-throughput" setup (9 EEG + 2 EMG), so I will have more to say about this in the future.


The most critical point is that classifier accuracy will be extremely dependent on the nature of the data itself. Mice data is not standard (in the same way human data is somewhat standardized). I expect every lab will need to re-train with data they generated themselves before getting good results with any classifier. I have not done so yet because I haven't collected enough data, but I don't see this being an issue in the future (provided a few others in the lab and I can label our own data!).

An alternative would be to find ways of normalizing datasets so that the feature extractor would extract comparable feature values from all of them. This is easier said than done, but it could be great given that people are sharing open datasets like the one I used for training.

@raphaelvallat
Owner

This is really great work @matiasandina, thank you!

> Mice data is not standard

Could you say more? What if we only include features that are normalized to zero mean and unit variance within each recording, i.e. we normalize data amplitude across recordings?
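A minimal sketch of that suggestion, assuming data shaped (n_channels, n_samples) as returned by get_data(): each channel is scaled to zero mean and unit variance within its own recording. zscore_per_recording is a hypothetical helper. The class docstring warns against z-scoring before running the default classifiers, so this would only make sense with a classifier retrained on normalized inputs. Note also that a pure rescaling leaves scale-free features (skewness, kurtosis, Hjorth parameters, relative band power, entropy) unchanged; in practice it mostly affects std, iqr and abspow.

```python
import numpy as np


def zscore_per_recording(data):
    """Normalize each channel to zero mean and unit variance within a recording.

    data : array of shape (n_channels, n_samples).
    """
    data = np.asarray(data, dtype=float)
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)
    std[std == 0] = 1.0  # flat channel: avoid division by zero
    return (data - mean) / std


# Usage: two recordings with very different amplitudes end up on the same scale.
rng = np.random.default_rng(0)
rec_a = rng.normal(0, 50e-6, (2, 1000))   # e.g. values in volts
rec_b = rng.normal(0, 50.0, (2, 1000))    # e.g. values in microvolts
print(zscore_per_recording(rec_a).std(axis=1), zscore_per_recording(rec_b).std(axis=1))
```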

I can do this inside a yasa_mice branch on my repo if you want, and we can evaluate later whether to merge it.

I am still undecided on the best way to go: a separate mice branch in YASA (which would, however, require users to fork the repo on GitHub and install the development version of YASA), or a whole new repository (pip install yasa_mice). In both cases, we will need to keep the fork/repo up to date with the main YASA repo. A third option could be to add a species="human" parameter to the yasa.SleepStaging class, which would automatically switch to the relevant features and classifier. This would, however, add some burden to the code.

> for each of the 50 4h recordings that I used for testing.

Do you think that the length of the input data will change the output? Can you just pass 24-hour recordings or even multi-day recording to the sleep staging function?

If there are mice researchers following this thread, please feel free to chime in with your ideas and preferences :)

@matiasandina
Contributor Author

Mice data is not standardized in the same way human data might be. Here are a few quick examples.

  1. There is no standard topoplot, because people place electrodes at rough, non-standard coordinates. In humans, each electrode has a standard name, so you can plot this sort of thing:

[Image]

This is not true for mice. The coordinates are usually modified to suit what the experimenter wants (one valid reason to do so is that you might want to implant something else in the brain in addition to the EEG electrodes, and there is no space left if you follow what others have done).

  2. The nature of the recordings is different (people use different materials for skull screws or different implants, and different acquisition and filtering methods; the vast majority of these are made in-house, and the criteria are established by each experimenter).

  3. If you are not acquiring data with commercial solutions, the data can have whatever structure your acquisition produces (for example, I acquire in ColumnMajor with no metadata). Again, there might be good reasons to choose this (for example, there are not many solid commercial alternatives that can handle multiple streams of information in real time and control the hardware I need for my experiments).

This is not to say the data is corrupted or of low quality; it's just less industrialized (?) and less plug-and-play than I imagine human data is.

> Do you think that the length of the input data will change the output? Can you just pass 24-hour recordings or even multi-day recording to the sleep staging function?

The training was done with 24-h recordings. I don't think the length of the input data would change the output. I don't have multi-day recordings at hand right now, but it would be nice to try.


Regarding the branch question: maybe the real issue is whether the branches would diverge so much that keeping them together becomes more of a burden for maintainers than splitting them. I think it might be worth keeping it all together. I'm not sure how they implemented it in code, but the DeepLabCut developers have taken the species = x route for their dlc-modelzoo.

@raphaelvallat
Owner

@matiasandina thanks for the detailed explanation! Another naive question: is mouse sleep similar to rat sleep? Do you think your classifier would work well on rat data?

Maybe another option is to have some sort of configuration file that determines the features, e.g. which features to compute, the length of the epoch, the length of the smoothing window, etc. That way, you just need the config file (*.json or *.yaml) and the updated classifier to run YASA on another species.
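A sketch of what such a config could look like, assuming JSON and a loader that validates the values before they are passed to fit(epoch_sec=..., bands=...) from the fork above. Every key name here (species, smoothing_sec, temporal_features, classifier) and the classifier filename are hypothetical; the epoch length is the 2.5 s used in this thread, and the bands and smoothing windows (7.5 min centered, 2 min past) are the defaults from staging.py.

```python
import json

# Hypothetical schema: none of these keys exist in YASA today.
MICE_CONFIG = """
{
  "species": "mouse",
  "epoch_sec": 2.5,
  "bands": [[0.4, 1, "sdelta"], [1, 4, "fdelta"], [4, 8, "theta"],
            [8, 12, "alpha"], [12, 16, "sigma"], [16, 30, "beta"]],
  "smoothing_sec": {"centered": 450, "past": 120},
  "temporal_features": false,
  "classifier": "clf_eeg+emg_mouse.joblib"
}
"""


def load_staging_config(text):
    """Parse a JSON staging config and run basic sanity checks."""
    cfg = json.loads(text)
    cfg["bands"] = [(lo, hi, name) for lo, hi, name in cfg["bands"]]
    assert cfg["epoch_sec"] > 0, "epoch_sec must be positive"
    for lo, hi, name in cfg["bands"]:
        assert 0 <= lo < hi, f"invalid band {name}"
    return cfg


# Usage (sls being a SleepStaging instance from the fork):
# sls.fit(epoch_sec=cfg["epoch_sec"], bands=cfg["bands"])
cfg = load_staging_config(MICE_CONFIG)
print(cfg["species"], cfg["epoch_sec"], len(cfg["bands"]))
```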

@matiasandina
Contributor Author

Sorry for the late reply. Provided re-training, I think this is flexible to work multi-species. I like the idea of a config file that determines the features! It's been a bit difficult to find some time to work on this since coming back from vacation and trying to get my PhD in motion again, but it's on the list!

@BryanWang0702

Hi, I am also trying to apply YASA to mouse data for automatic staging, and I am wondering how much progress you have made so far. I used a 5-s sliding window to split the EEG and EMG data, and I computed the following features:

  • Time domain
    • Standard deviation
    • Zero-crossing rate
    • Hjorth Parameters
      • Mobility
      • Complexity
    • Permutation entropy
    • Skewness and Kurtosis (only EEG)
  • Frequency domain
    • Delta/theta ratio
    • Theta band power
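For reference, here is a rough sketch of how the time-domain features in this list could be computed per epoch with NumPy and SciPy. It is not Xueqiang's code: the sampling rate, the permutation-entropy order and delay (3 and 1), and the feature names are assumptions, and a library such as antropy provides tested implementations of permutation entropy, Hjorth parameters and zero crossings.

```python
import math

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from scipy.stats import kurtosis, skew


def epochs(x, sf, win_sec=5.0, step_sec=None):
    """Split a 1-D signal into (n_epochs, n_samples) windows.

    step_sec=None gives non-overlapping epochs; a smaller step gives a sliding window.
    """
    win = int(win_sec * sf)
    step = win if step_sec is None else int(step_sec * sf)
    return sliding_window_view(x, win)[::step]


def zero_crossing_rate(x):
    # Fraction of consecutive sample pairs whose sign differs
    return float(np.mean(np.signbit(x[1:]) != np.signbit(x[:-1])))


def hjorth(x):
    """Return (mobility, complexity) from the variances of x and its differences."""
    dx, ddx = np.diff(x), np.diff(x, n=2)
    mobility = np.sqrt(dx.var() / x.var())
    complexity = np.sqrt(ddx.var() / dx.var()) / mobility
    return float(mobility), float(complexity)


def perm_entropy(x, order=3, delay=1):
    """Normalized permutation entropy (0 = fully regular, 1 = all patterns equally likely)."""
    n = len(x) - (order - 1) * delay
    idx = np.arange(n)[:, None] + np.arange(order) * delay
    patterns = np.argsort(x[idx], axis=1)
    # Encode each ordinal pattern as a unique integer (base-`order` digits)
    codes = patterns @ (order ** np.arange(order))
    _, counts = np.unique(codes, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum() / np.log2(math.factorial(order)))


def time_features(epoch, prefix, with_shape=False):
    mob, comp = hjorth(epoch)
    feats = {
        f"{prefix}_std": float(epoch.std()),
        f"{prefix}_zcr": zero_crossing_rate(epoch),
        f"{prefix}_hjorth_mob": mob,
        f"{prefix}_hjorth_comp": comp,
        f"{prefix}_perm_ent": perm_entropy(epoch),
    }
    if with_shape:  # skewness and kurtosis were only used for EEG
        feats[f"{prefix}_skew"] = float(skew(epoch))
        feats[f"{prefix}_kurt"] = float(kurtosis(epoch))
    return feats


# Synthetic 60-s recording at an assumed 100 Hz
sf = 100.0
rng = np.random.default_rng(0)
eeg = rng.standard_normal(int(60 * sf))
emg = rng.standard_normal(int(60 * sf))
rows = [
    {**time_features(e, "eeg", with_shape=True), **time_features(m, "emg")}
    for e, m in zip(epochs(eeg, sf), epochs(emg, sf))
]
print(len(rows), len(rows[0]))  # 12 12
```

With one EEG and one EMG channel this gives 12 time-domain features per epoch; if the delta/theta ratio and theta power are computed on EEG only, the total would be 14, matching the count below.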

I used these 14 features and trained several LightGBM models (we have different types of mice, so I trained several models) on our own human-labeled data. Consistent with the YASA paper, I used grid search for the hyperparameters. The models reach F1 scores above 0.8 for Wake and NREM, but for REM the median F1 score is only about 0.5. See the picture below.
[Image: F1 scores per stage]
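Because REM is only a small fraction of the epochs, per-stage F1 says much more than overall accuracy here. A minimal NumPy sketch of per-stage F1 (similar to `sklearn.metrics.f1_score(..., average=None)`), assuming string stage labels "W", "N" and "R":

```python
import numpy as np


def per_stage_f1(y_true, y_pred, stages=("W", "N", "R")):
    """F1 = 2*TP / (2*TP + FP + FN), computed separately for each stage."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = {}
    for s in stages:
        tp = np.sum((y_pred == s) & (y_true == s))
        fp = np.sum((y_pred == s) & (y_true != s))
        fn = np.sum((y_pred != s) & (y_true == s))
        denom = 2 * tp + fp + fn
        # NaN when the stage never appears in either labels or predictions
        scores[s] = float(2 * tp / denom) if denom else float("nan")
    return scores


# Toy example: one REM epoch misclassified as Wake
y_true = ["W", "W", "N", "N", "R", "R"]
y_pred = ["W", "W", "N", "N", "W", "R"]
print(per_stage_f1(y_true, y_pred))  # {'W': 0.8, 'N': 1.0, 'R': 0.6666666666666666}
```

A single REM epoch scored as Wake costs REM a third of its F1, while Wake barely moves; this is why REM errors dominate the picture when REM is rare.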

I think REM is the easiest stage to detect, but somehow REM accuracy is the lowest. Are you seeing the same thing, @matiasandina?

I then checked the predicted probabilities for all stages and found that the REM probability always stays low. I don't know whether this is normal.
[Image: predicted probability of each stage]

Like @matiasandina, I also found that the predictions are fragmented (#139). So I added some constraints to make them smoother, like this:
[Image: predictions after adding the smoothing constraints]
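The exact constraints are only shown in the image, so here is a hypothetical sketch of two common post-processing rules: relabel REM that directly follows Wake, and merge bouts shorter than a minimum number of epochs into the preceding stage. The function names and the `min_len` default are illustrative assumptions.

```python
def forbid_wake_to_rem(hypno, replacement="W"):
    """Relabel any REM epoch whose preceding epoch is Wake (hypothetical rule).

    Runs left to right, so with replacement="W" a whole W, R, R, R run becomes Wake.
    """
    out = list(hypno)
    for i in range(1, len(out)):
        if out[i] == "R" and out[i - 1] == "W":
            out[i] = replacement
    return out


def merge_short_bouts(hypno, min_len=2):
    """Relabel bouts shorter than min_len epochs with the preceding stage.

    The first bout is never relabeled because it has no predecessor.
    """
    out = list(hypno)
    i = 0
    while i < len(out):
        j = i
        while j < len(out) and out[j] == out[i]:
            j += 1
        if i > 0 and j - i < min_len:
            out[i:j] = [out[i - 1]] * (j - i)
        i = j
    return out


print(forbid_wake_to_rem(["R", "W", "R", "R"]))                   # ['R', 'W', 'W', 'W']
print(forbid_wake_to_rem(["R", "W", "R", "R"], replacement="N"))  # ['R', 'W', 'N', 'R']
print(merge_short_bouts(["N", "N", "W", "N", "N"]))              # ['N', 'N', 'N', 'N', 'N']
```

The choice of replacement matters: replacing the forbidden REM with Wake wipes out the whole REM run, while replacing it with NREM produces a REM -> Wake -> NREM -> REM pattern.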

I want to solve the low REM accuracy issue; do you have any suggestions? @raphaelvallat

Thank you,
Xueqiang

@matiasandina
Contributor Author

At some point, I started facing a significant bias towards Wake. This problem was much more damaging to the project than the fragmentation and other issues I previously mentioned, and it stopped me from continuing to integrate this into the package.

I tried to re-train the classifier many times (see here), but I could never get to the bottom of why there was some sort of data drift or change in the classifier's behavior.

With my classifiers, I'm getting very high Wake percentages (around 70-80%). This is especially bad since a single-feature classifier based on logRMS(EMG) performs "quite well" at distinguishing wake from sleep. Even adding this feature to the classifier did not seem to matter much.
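To make the single-feature baseline concrete, here is a minimal sketch: compute the log10 RMS of the EMG per epoch and pick the cut-off that maximizes balanced accuracy for Wake vs sleep on labeled epochs. The threshold search and the synthetic data are assumptions for illustration, not the baseline actually used here.

```python
import numpy as np


def log_rms(emg_epochs):
    """log10 of the RMS amplitude per epoch; emg_epochs has shape (n_epochs, n_samples)."""
    return np.log10(np.sqrt(np.mean(np.square(emg_epochs), axis=1)))


def fit_threshold(feature, is_wake):
    """Return the cut-off (predict Wake if feature >= t) maximizing balanced accuracy."""
    best_t, best_score = None, -np.inf
    for t in np.unique(feature):
        pred_wake = feature >= t
        sensitivity = np.mean(pred_wake[is_wake])
        specificity = np.mean(~pred_wake[~is_wake])
        score = (sensitivity + specificity) / 2
        if score > best_score:
            best_t, best_score = float(t), float(score)
    return best_t, best_score


# Synthetic example (assumption): wake EMG has ~5x the amplitude of sleep EMG
rng = np.random.default_rng(42)
emg = np.vstack([rng.normal(0, 5, (200, 250)), rng.normal(0, 1, (200, 250))])
is_wake = np.concatenate([np.ones(200, dtype=bool), np.zeros(200, dtype=bool)])
t, score = fit_threshold(log_rms(emg), is_wake)
print(f"threshold={t:.2f} balanced accuracy={score:.2f}")
```

Comparing the Wake percentage from such a baseline with the LightGBM output on the same recording is one quick way to see whether the bias towards Wake comes from the features or from the model.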

As I mentioned before, either in this issue or another one, the fact that LightGBM relies on computed features makes it sensitive to data drift, and I think that to really use this, a lab should ensure the training and test sets come from the same recording devices. Upon manual inspection, the open training data and my data do not differ in ways that a human would have issues with (or in ways that I attempted to quantify with features). And yet, the YASA classifier I trained went bananas.
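One way to try to quantify drift between the open training data and a new lab's data is a per-feature two-sample Kolmogorov-Smirnov test (`scipy.stats.ks_2samp`) on the epoch-level feature matrices. This is a hedged sketch, assuming both datasets have already been turned into feature arrays with the same columns; the feature names are hypothetical. With many thousands of epochs even tiny shifts become "significant", so the KS statistic is probably more useful for ranking features than the p-value.

```python
import numpy as np
from scipy.stats import ks_2samp


def feature_drift(train, new, names, alpha=0.01):
    """Per-feature two-sample KS test; rows are epochs, columns are features.

    Returns (name, KS statistic, p-value, flagged) sorted by decreasing statistic.
    """
    report = []
    for k, name in enumerate(names):
        res = ks_2samp(train[:, k], new[:, k])
        report.append((name, float(res.statistic), float(res.pvalue), bool(res.pvalue < alpha)))
    return sorted(report, key=lambda r: r[1], reverse=True)


# Synthetic example (assumption): only the EMG feature is shifted in the "new" lab
names = ["eeg_std", "eeg_perm_ent", "emg_log_rms"]
rng = np.random.default_rng(0)
train = rng.normal(size=(2000, 3))
new = rng.normal(size=(2000, 3))
new[:, 2] += 1.0
report = feature_drift(train, new, names)
for name, stat, p, flagged in report:
    print(f"{name:14s} D={stat:.3f} p={p:.1e} drift={flagged}")
```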

When it was working better, REM was somewhat shorter than a human would score it, and the transitions were quite strange (e.g., NREM -> W -> REM -> NREM -> W instead of NREM -> REM -> REM -> REM -> W). This last point is understandable since there are no rules to rein in this behavior (W->REM transitions are allowed).
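Counting stage transitions makes these odd sequences easy to spot, e.g., how often W -> REM occurs in the predicted versus the human-scored hypnogram. A minimal sketch assuming string labels; YASA's `yasa.transition_matrix` offers similar functionality for integer-coded hypnograms.

```python
import numpy as np


def transition_counts(hypno, stages=("W", "N", "R")):
    """Count epoch-to-epoch transitions; rows = from-stage, columns = to-stage."""
    idx = {s: i for i, s in enumerate(stages)}
    counts = np.zeros((len(stages), len(stages)), dtype=int)
    for a, b in zip(hypno[:-1], hypno[1:]):
        counts[idx[a], idx[b]] += 1
    return counts


hypno = list("NNWRNW")  # toy predicted hypnogram
counts = transition_counts(hypno)
print(counts)
print("W -> REM transitions:", counts[0, 2])  # W -> REM transitions: 1
```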

I am happy to talk about this at length and share data/ideas, but I hit a bit of a dead end on my side.

@BryanWang0702

Well, I see. We use different recording devices, but I found that the LightGBM classifiers perform almost the same on data from different devices. See the comment above: the mousetype3 data was recorded with a different device.

In my case, I have EEG/EMG data from mice of different ages, so I trained a set of LightGBM models for different ages (see auto_stage_model). I also separated the frontal and parietal channels, because REM is more distinctive in the parietal channel.

> I tried to re-train the classifier many times (see [here](https://github.com/matiasandina/yasa_classifier/tree/yasa-accusleep-eval/output/classifiers/accusleep_512)). Though I could never get to the bottom of why there was some sort of data drift or change in the classifier's behavior.
Trying a different theta range was a nice idea; I also tried this and saw no significant difference.

I also noticed that there is much more Wake data than NREM and REM: my Wake:NREM:REM ratio is about 14:7:1. So I resampled the training data to a 1:1:1 ratio and retrained the model, but found no significant difference.
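A minimal sketch of 1:1:1 rebalancing by random undersampling, assuming a feature matrix `X` and a label vector `y` as NumPy arrays. An alternative that keeps all the data is class weighting, e.g. `LGBMClassifier(class_weight="balanced")` in LightGBM's scikit-learn API. Either way, evaluation should still be done on the natural (imbalanced) stage distribution.

```python
import numpy as np


def undersample(X, y, seed=0):
    """Randomly drop epochs so every stage keeps as many epochs as the rarest stage."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    n_min = counts.min()
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=n_min, replace=False) for c in classes
    ])
    keep.sort()  # preserve the original temporal order of the kept epochs
    return X[keep], y[keep]


# Toy data with a 14:7:1 Wake:NREM:REM ratio
y = np.array(["W"] * 1400 + ["N"] * 700 + ["R"] * 100)
X = np.random.default_rng(1).normal(size=(len(y), 14))
Xb, yb = undersample(X, y)
print(np.unique(yb, return_counts=True))  # 100 epochs per stage
```

Undersampling after the rolling/smoothed features have been computed keeps those features valid; dropping epochs from the raw signal first would break their temporal context.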

Most of my REM errors look like REM -> Wake -> REM. When I added the constraint "no REM after Wake", the result turned into REM -> Wake -> NREM -> REM, which is worse. I think the model tends to misclassify REM as Wake in my case; maybe we can add alpha- or gamma-band features to distinguish REM from Wake. I don't know yet, but I will try it.
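A sketch of how alpha- and gamma-band power could be added next to the existing delta/theta features, using Welch's method (`scipy.signal.welch`). The band edges below are assumptions (rodent band definitions vary between labs), as are the 2-s Welch window and the illustrative gamma/theta ratio; YASA's own `yasa.bandpower` can compute band powers as well.

```python
import numpy as np
from scipy.signal import welch

# Hypothetical band edges in Hz; adjust to your lab's definitions.
BANDS = {"delta": (0.5, 4), "theta": (6, 10), "alpha": (10, 13), "gamma": (30, 80)}


def band_powers(epoch, sf, bands=BANDS, win_sec=2.0):
    """Absolute band power (PSD summed over each band times df) plus two ratios."""
    freqs, psd = welch(epoch, fs=sf, nperseg=int(win_sec * sf))
    df = freqs[1] - freqs[0]
    out = {}
    for name, (lo, hi) in bands.items():
        mask = (freqs >= lo) & (freqs < hi)
        out[name] = float(psd[mask].sum() * df)
    out["delta_theta_ratio"] = out["delta"] / out["theta"]
    out["gamma_theta_ratio"] = out["gamma"] / out["theta"]  # illustrative only
    return out


# Synthetic 5-s epoch at an assumed 256 Hz: 40-Hz sine plus weak noise
sf = 256
t = np.arange(5 * sf) / sf
x = np.sin(2 * np.pi * 40 * t) + 0.1 * np.random.default_rng(0).standard_normal(t.size)
print({k: round(v, 4) for k, v in band_powers(x, sf).items()})
```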

Maybe you have heard of DeepLabCut: it has a fine-tuning process with your own video data. I was trying to follow that idea, i.e., label some data and use it to fine-tune the LightGBM model, but found that a small set of labeled data cannot tune the LightGBM model. So I am trying other strategies, but I still think this could be a good idea. And if we trained a CNN instead, maybe we could fine-tune it.

I saw you trained a model based on AccuSleep? Does it perform better on mouse data? I haven't tried it yet.

Thanks,
Xueqiang
