Skip to content

Commit

Permalink
feat(drift): added drift mvp
Browse files Browse the repository at this point in the history
Features:
- reference covariate drift
- reference label drift
- sample covariate drift
- sample label drift
- sample concept drift
  • Loading branch information
jfsantos-ds authored Jul 25, 2021
1 parent 0c7f9d3 commit 7e54a2d
Show file tree
Hide file tree
Showing 6 changed files with 941 additions and 41 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pandas==1.2.*
pydantic==1.8.2
scikit-learn==0.24.2
matplotlib==3.4.2
44 changes: 42 additions & 2 deletions src/ydata_quality/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
import pandas as pd
from ydata_quality.core import QualityWarning
from ydata_quality.core.warnings import Priority

from ydata_quality.utils.modelling import infer_dtypes

class QualityEngine(ABC):
"Main class for running and storing data quality analysis."

def __init__(self, df: pd.DataFrame):
def __init__(self, df: pd.DataFrame, label: str = None, dtypes: dict = None):
self._df = df
self._warnings = set()
self._tests = []
self._label = label
self._dtypes = dtypes

@property
def df(self):
Expand All @@ -27,6 +29,44 @@ def warnings(self):
"Storage of all detected data quality warnings."
return self._warnings


@property
def label(self):
"Property that returns the label under inspection."
return self._label

@label.setter
def label(self, label: str):
if not isinstance(label, str):
raise ValueError("Property 'label' should be a string.")
assert label in self.df.columns, "Given label should exist as a DataFrame column."
self._label = label

@property
def dtypes(self):
"Infered dtypes for the dataset."
if self._dtypes is None:
self._dtypes = infer_dtypes(self.df)
return self._dtypes

@dtypes.setter
def dtypes(self, dtypes: dict):
if not isinstance(dtypes, dict):
raise ValueError("Property 'dtypes' should be a dictionary.")
assert all(col in self.df.columns for col in dtypes), "All dtypes keys \
must be columns in the dataset."
supported_dtypes = ['numerical', 'categorical']
assert all(dtype in supported_dtypes for dtype in dtypes.values()), "Assigned dtypes\
must be in the supported broad dtype list: {}.".format(supported_dtypes)
df_col_set = set(self.df.columns)
dtypes_col_set = set(dtypes.keys())
missing_cols = df_col_set.difference(dtypes_col_set)
if missing_cols:
_dtypes = infer_dtypes(self.df, skip=df_col_set.difference(missing_cols))
for col, dtype in _dtypes.items():
dtypes[col] = dtype
self._dtypes = dtypes

def store_warning(self, warning: QualityWarning):
"Adds a new warning to the internal 'warnings' storage."
self._warnings.add(warning)
Expand Down
9 changes: 9 additions & 0 deletions src/ydata_quality/drift/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Tools to check dataset for data drifting.
"""
from ydata_quality.drift.engine import DriftAnalyser, ModelWrapper

__all__ = [
"DriftAnalyser",
"ModelWrapper"
]
Loading

0 comments on commit 7e54a2d

Please sign in to comment.