feat: dataset Prediction Id and Timestamp normalization #166

Merged · 28 commits · Jan 24, 2023
Changes shown from 24 of 28 commits.
31 changes: 26 additions & 5 deletions src/phoenix/datasets/dataset.py
@@ -6,7 +6,8 @@
 from dataclasses import fields, replace
 from typing import Any, Dict, List, Optional, Set, Tuple, Union

-from pandas import DataFrame, Series, read_parquet
+from pandas import DataFrame, Series, Timestamp, read_parquet, to_datetime
+from pandas.api.types import is_numeric_dtype

 from phoenix.config import dataset_dir

@@ -268,7 +269,7 @@ def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> Tuple[DataFrame, Schema]:
             "not found in the dataframe: {}".format(", ".join(unseen_excluded_column_names))
         )

-    parsed_dataframe, parsed_schema = _create_parsed_dataframe_and_schema(
+    parsed_dataframe, parsed_schema = _create_and_normalize_dataframe_and_schema(
         dataframe, schema, schema_patch, column_name_to_include
     )

@@ -400,20 +401,40 @@ def _discover_feature_columns(
     )


-def _create_parsed_dataframe_and_schema(
+def _create_and_normalize_dataframe_and_schema(
     dataframe: DataFrame,
     schema: Schema,
     schema_patch: Dict[SchemaFieldName, SchemaFieldValue],
     column_name_to_include: Dict[str, bool],
 ) -> Tuple[DataFrame, Schema]:
     """
     Creates new dataframe and schema objects to reflect excluded column names
-    and discovered features.
+    and discovered features. This also normalizes dataframe columns to ensure a
+    standard set of columns (i.e. timestamp and prediction_id) and datatypes for
+    those columns.
     """
     included_column_names: List[str] = []
     for column_name in dataframe.columns:
         if column_name_to_include.get(str(column_name), False):
             included_column_names.append(str(column_name))
-    parsed_dataframe = dataframe[included_column_names]
+    parsed_dataframe = dataframe[included_column_names].copy()
     parsed_schema = replace(schema, excludes=None, **schema_patch)
+
+    ts_col_name = parsed_schema.timestamp_column_name
+    if ts_col_name is None:
+        now = Timestamp.utcnow()
+        parsed_schema = replace(parsed_schema, timestamp_column_name="timestamp")
+        parsed_dataframe["timestamp"] = now
+    elif is_numeric_dtype(dataframe.dtypes[ts_col_name]):
+        parsed_dataframe[ts_col_name] = to_datetime(
+            parsed_dataframe[ts_col_name], unit="ms"
+        )
+
+    pred_col_name = parsed_schema.prediction_id_column_name
+    if pred_col_name is None:
+        parsed_schema = replace(parsed_schema, prediction_id_column_name="prediction_id")
+        parsed_dataframe["prediction_id"] = parsed_dataframe.apply(
+            lambda _: str(uuid.uuid4()), axis=1
+        )
+    elif is_numeric_dtype(parsed_dataframe.dtypes[pred_col_name]):
+        parsed_dataframe[pred_col_name] = parsed_dataframe[pred_col_name].astype(str)
+
     return parsed_dataframe, parsed_schema
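
As a reviewer aid, here is a minimal sketch of the two normalizations this function now performs; the dataframe and column names below are illustrative, not from the PR:

    import uuid

    from pandas import DataFrame, to_datetime

    # Hypothetical input: epoch-millisecond timestamps, no prediction id column.
    df = DataFrame({"ts": [1674000000000, 1674000060000]})

    # Numeric timestamp column -> pandas datetimes (the unit="ms" branch above).
    df["ts"] = to_datetime(df["ts"], unit="ms")

    # Missing prediction id column -> one fresh UUID string per row
    # (axis=1 applies the lambda once per row, as in the branch above).
    df["prediction_id"] = df.apply(lambda _: str(uuid.uuid4()), axis=1)

    print(df.dtypes)  # ts: datetime64[ns], prediction_id: object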
25 changes: 23 additions & 2 deletions src/phoenix/datasets/errors.py
@@ -1,8 +1,8 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Iterable, List, Union


-class ValidationError(ABC):
+class ValidationError(Exception):
     def __repr__(self) -> str:
         return self.__class__.__name__

@@ -42,13 +42,34 @@ def error_message(self) -> str:
         )


+class InvalidSchemaError(ValidationError):
+    def __repr__(self) -> str:
+        return self.__class__.__name__
+
+    def __init__(self, invalid_props: Iterable[str]) -> None:
+        self.invalid_props = invalid_props
+
+    def error_message(self) -> str:
+        return f"The schema is invalid: {', '.join(map(str, self.invalid_props))}."
+
+
+class DatasetError(Exception):
+    """An error raised when the dataset is invalid or incomplete"""
+
+    def __init__(self, errors: Union[ValidationError, List[ValidationError]]):
+        self.errors = errors


+class InvalidColumnType(ValidationError):
+    """An error raised when the column type is invalid"""
+
+    def __init__(self, error_msgs: Iterable[str]) -> None:
+        self.error_msgs = error_msgs
+
+    def error_message(self) -> str:
+        return f"Invalid column types: {self.error_msgs}"

Review thread on this hunk:

Contributor Author: Why does DatasetError only inherit from the BaseException (above) and not ValidationError as well?

Contributor: Don't remember the rationale but I think the thought was to have a set of errors that made it clear which was to blame - the dataframe being malformed or the schema being mis-configured. Feel free to fix the inheritance as it makes sense.


 class MissingField(ValidationError):
     """An error raised when trying to access a field that is absent from the Schema"""

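A hedged sketch of how these error types presumably compose at the call site; the raise site is not part of this diff, and the import paths are inferred from the file layout:

    from pandas import DataFrame

    from phoenix.datasets import errors as err
    from phoenix.datasets.schema import Schema
    from phoenix.datasets.validation import validate_dataset_inputs

    def validate_or_raise(dataframe: DataFrame, schema: Schema) -> None:
        # Assumed pattern: collect every ValidationError, then wrap the list
        # in a single DatasetError so callers can catch one exception type.
        validation_errors = validate_dataset_inputs(dataframe, schema)
        if validation_errors:
            raise err.DatasetError(validation_errors)

Here validate_or_raise is a hypothetical helper name, not one defined in this PR.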
2 changes: 1 addition & 1 deletion src/phoenix/datasets/schema.py
@@ -39,7 +39,7 @@ class Schema(Dict[SchemaFieldName, SchemaFieldValue]):

     def to_json(self) -> str:
         "Converts the schema to a dict for JSON serialization"
-        dictionary = self.__dict__
+        dictionary = {}

Review comment (Contributor): Good find.

         for field in self.__dataclass_fields__:
             value = getattr(self, field)
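The one-line change above fixes an aliasing bug: self.__dict__ is the instance's live attribute dictionary, not a copy, so any key added or changed during serialization would silently mutate the Schema itself. A standalone illustration:

    class Thing:
        def __init__(self) -> None:
            self.a = 1

    t = Thing()
    d = t.__dict__  # an alias of the instance's attributes, not a copy
    d["a"] = 99     # intended as a local, serialization-only tweak...
    print(t.a)      # ...but prints 99: the instance was mutated too

Building a fresh dictionary, as the new code does, keeps to_json side-effect free.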
59 changes: 57 additions & 2 deletions src/phoenix/datasets/validation.py
@@ -2,17 +2,71 @@
 from typing import List

 from pandas import DataFrame
+from pandas.api.types import is_datetime64_any_dtype as is_datetime
+from pandas.api.types import is_numeric_dtype, is_string_dtype

 from . import errors as err
 from .schema import Schema

+def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
+    errs: List[str] = []
+    if schema.excludes is None:
+        return []
+
+    if schema.timestamp_column_name in schema.excludes:
+        errs.append(
+            f"{schema.timestamp_column_name} cannot be excluded because "
+            f"it is already being used as the timestamp column"
+        )
+
+    if schema.prediction_id_column_name in schema.excludes:
+        errs.append(
+            f"{schema.prediction_id_column_name} cannot be excluded because "
+            f"it is already being used as the prediction id column"
+        )
+
+    if len(errs) > 0:
+        return [err.InvalidSchemaError(errs)]
+
+    return []
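
A reviewer sketch of the contradiction this check rejects; it assumes the Schema dataclass fields default to None and accept these keyword arguments:

    from phoenix.datasets.schema import Schema

    # "ts" is both the designated timestamp column and an excluded column.
    schema = Schema(timestamp_column_name="ts", excludes=["ts"])
    for e in _check_valid_schema(schema):
        print(e.error_message())
    # Expected: "The schema is invalid: ts cannot be excluded because
    # it is already being used as the timestamp column."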


 def validate_dataset_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
-    general_checks = chain(check_missing_columns(dataframe, schema))
+    general_checks = chain(
+        _check_missing_columns(dataframe, schema),
+        _check_column_types(dataframe, schema),
+        _check_valid_schema(schema),
+    )
     return list(general_checks)


-def check_missing_columns(dataframe: DataFrame, schema: Schema) -> List[err.MissingColumns]:
+def _check_column_types(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
+    wrong_type_cols: List[str] = []
+    if schema.timestamp_column_name is not None:
+        if not (
+            is_numeric_dtype(dataframe.dtypes[schema.timestamp_column_name])
+            or is_datetime(dataframe.dtypes[schema.timestamp_column_name])
+        ):
+            wrong_type_cols.append(
+                f"{schema.timestamp_column_name} should be of timestamp or numeric type"
+            )
+
+    if schema.prediction_id_column_name is not None:
+        if not (
+            is_numeric_dtype(dataframe.dtypes[schema.prediction_id_column_name])
+            or is_string_dtype(dataframe.dtypes[schema.prediction_id_column_name])
+        ):
+            wrong_type_cols.append(
+                f"{schema.prediction_id_column_name} should be a string or numeric type"
+            )
+
+    if len(wrong_type_cols) > 0:
+        return [err.InvalidColumnType(wrong_type_cols)]
+    return []
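
And a sketch of an input that trips the type check (the dataframe contents are hypothetical): a datetime-typed prediction id column is neither string nor numeric dtype, so it should be reported.

    from pandas import DataFrame, Timestamp

    from phoenix.datasets.schema import Schema

    df = DataFrame({"pred_id": [Timestamp.utcnow()], "ts": [1674000000000]})
    schema = Schema(prediction_id_column_name="pred_id", timestamp_column_name="ts")
    for e in _check_column_types(df, schema):
        print(e.error_message())
    # Expected: Invalid column types: ['pred_id should be a string or numeric type']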


+def _check_missing_columns(dataframe: DataFrame, schema: Schema) -> List[err.MissingColumns]:
     # converting to a set first makes the checks run a lot faster
     existing_columns = set(dataframe.columns)
     missing_columns = []
@@ -45,4 +99,5 @@ def check_missing_columns(dataframe: DataFrame, schema: Schema) -> List[err.MissingColumns]:

     if missing_columns:
         return [err.MissingColumns(missing_columns)]
+
     return []