From 1f8e09d346b43c9dc52648ee0df2f586f0c1868d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9d=C3=A9ric=20Hurier=20=28Fmind=29?= Date: Sat, 14 Dec 2024 10:07:56 +0100 Subject: [PATCH] refactor(code): improve abstractions --- src/bikes/io/__init__.py | 2 +- src/bikes/io/datasets.py | 3 ++- src/bikes/io/services.py | 2 +- src/bikes/utils/signers.py | 2 +- src/bikes/utils/splitters.py | 4 +++- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/bikes/io/__init__.py b/src/bikes/io/__init__.py index 2a8a67d..044aa70 100644 --- a/src/bikes/io/__init__.py +++ b/src/bikes/io/__init__.py @@ -1 +1 @@ -"""Components related to external operations.""" +"""Components related to external operations (inputs and outputs).""" diff --git a/src/bikes/io/datasets.py b/src/bikes/io/datasets.py index 7cf2933..677c654 100644 --- a/src/bikes/io/datasets.py +++ b/src/bikes/io/datasets.py @@ -69,11 +69,12 @@ class ParquetReader(Reader): KIND: T.Literal["ParquetReader"] = "ParquetReader" path: str + backend: T.Literal["pyarrow", "numpy_nullable"] = "pyarrow" @T.override def read(self) -> pd.DataFrame: # can't limit rows at read time - data = pd.read_parquet(self.path) + data = pd.read_parquet(self.path, dtype_backend=self.backend) if self.limit is not None: data = data.head(self.limit) return data diff --git a/src/bikes/io/services.py b/src/bikes/io/services.py index 8a2c572..5743bef 100644 --- a/src/bikes/io/services.py +++ b/src/bikes/io/services.py @@ -199,7 +199,7 @@ def run_context(self, run_config: RunConfig) -> T.Generator[mlflow.ActiveRun, No run (str): run parameters. Yields: - T.Generator[mlflow.ActiveRun, None, None]: active run context. Will be closed as the end of context. + T.Generator[mlflow.ActiveRun, None, None]: active run context. Will be closed at the end of context. 
""" with mlflow.start_run( run_name=run_config.name, diff --git a/src/bikes/utils/signers.py b/src/bikes/utils/signers.py index b976bb5..4a0a5ce 100644 --- a/src/bikes/utils/signers.py +++ b/src/bikes/utils/signers.py @@ -21,7 +21,7 @@ class Signer(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"): """Base class for generating model signatures. - Allow to switch between model signing strategies. + Allow switching between model signing strategies. e.g., automatic inference, manual model signature, ... https://mlflow.org/docs/latest/models.html#model-signature-and-input-example diff --git a/src/bikes/utils/splitters.py b/src/bikes/utils/splitters.py index bee5952..0740bc0 100644 --- a/src/bikes/utils/splitters.py +++ b/src/bikes/utils/splitters.py @@ -132,7 +132,9 @@ def split( targets: schemas.Targets, groups: Index | None = None, ) -> TrainTestSplits: - splitter = model_selection.TimeSeriesSplit(n_splits=self.n_splits, test_size=self.test_size) + splitter = model_selection.TimeSeriesSplit( + n_splits=self.n_splits, test_size=self.test_size, gap=self.gap + ) yield from splitter.split(inputs) @T.override