Skip to content

Commit

Permalink
Add support for new __sklearn_tags__
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Dec 16, 2024
1 parent 36a91c7 commit 77b5c85
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import pkg_resources
import sklearn.utils.validation as sklearn_utils_validation
import torch
import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils.metaestimators import available_if
from torch import nn

import cebra.data
Expand All @@ -41,6 +43,11 @@
import cebra.models
import cebra.solver

def check_version(estimator):
# NOTE(stes): required as a check for the old way of specifying tags
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
from packaging import version
return version.parse(sklearn.__version__) < version.parse("1.6.dev")

def _init_loader(
is_cont: bool,
Expand Down Expand Up @@ -1294,6 +1301,15 @@ def fit_transform(
callback_frequency=callback_frequency)
return self.transform(X)

def __sklearn_tags__(self):
# NOTE(stes): from 1.6.dev, this is the new way to specify tags
# https://scikit-learn.org/dev/developers/develop.html
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
tags = super().__sklearn_tags__()
tags.non_deterministic = True
return tags

@available_if(check_version)
def _more_tags(self):
# NOTE(stes): This tag is needed as seeding is not fully implemented in the
# current version of CEBRA.
Expand Down

0 comments on commit 77b5c85

Please sign in to comment.