Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: rename excludes schema field #460

Merged
merged 1 commit into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/phoenix/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,15 @@ def _parse_dataframe_and_schema(dataframe: DataFrame, schema: Schema) -> Tuple[D
they are not explicitly provided, and removes excluded column names from
both dataframe and schema.

Removes column names in `schema.excludes` from the input dataframe and
schema. To remove an embedding feature and all associated columns, add the
name of the embedding feature to `schema.excludes` rather than the
associated column names. If `schema.feature_column_names` is `None`,
automatically discovers features by adding all column names present in the
dataframe but not included in any other schema fields.
Removes column names in `schema.excluded_column_names` from the input dataframe and schema. To
remove an embedding feature and all associated columns, add the name of the embedding feature to
`schema.excluded_column_names` rather than the associated column names. If
`schema.feature_column_names` is `None`, automatically discovers features by adding all column
names present in the dataframe but not included in any other schema fields.
"""

unseen_excluded_column_names: Set[str] = (
set(schema.excludes) if schema.excludes is not None else set()
set(schema.excluded_column_names) if schema.excluded_column_names is not None else set()
)
unseen_column_names: Set[str] = set(dataframe.columns.to_list())
column_name_to_include: Dict[str, bool] = {}
Expand Down Expand Up @@ -509,7 +508,7 @@ def _create_and_normalize_dataframe_and_schema(
if column_name_to_include.get(str(column_name), False):
included_column_names.append(str(column_name))
parsed_dataframe = dataframe[included_column_names].copy()
parsed_schema = replace(schema, excludes=None, **schema_patch)
parsed_schema = replace(schema, excluded_column_names=None, **schema_patch)
pred_col_name = parsed_schema.prediction_id_column_name
if pred_col_name is None:
parsed_schema = replace(parsed_schema, prediction_id_column_name="prediction_id")
Expand Down
2 changes: 1 addition & 1 deletion src/phoenix/datasets/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Schema(Dict[SchemaFieldName, SchemaFieldValue]):
actual_label_column_name: Optional[str] = None
actual_score_column_name: Optional[str] = None
embedding_feature_column_names: Optional[EmbeddingFeatures] = None
excludes: Optional[List[str]] = None
excluded_column_names: Optional[List[str]] = None

def to_json(self) -> str:
"Converts the schema to a dict for JSON serialization"
Expand Down
6 changes: 3 additions & 3 deletions src/phoenix/datasets/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
errs: List[str] = []
if schema.excludes is None:
if schema.excluded_column_names is None:
return []

if schema.timestamp_column_name in schema.excludes:
if schema.timestamp_column_name in schema.excluded_column_names:
errs.append(
f"{schema.timestamp_column_name} cannot be excluded because "
f"it is already being used as the timestamp column"
)

if schema.prediction_id_column_name in schema.excludes:
if schema.prediction_id_column_name in schema.excluded_column_names:
errs.append(
f"{schema.prediction_id_column_name} cannot be excluded because "
f"it is already being used as the prediction id column"
Expand Down
50 changes: 25 additions & 25 deletions tests/datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_some_features_excluded_removes_excluded_features_columns_and_keeps_the_
feature_column_names=["feature0", "feature1"],
tag_column_names=["tag0"],
prediction_label_column_name="prediction_label",
excludes=["feature1"],
excluded_column_names=["feature1"],
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -121,7 +121,7 @@ def test_some_features_excluded_removes_excluded_features_columns_and_keeps_the_
prediction_label_column_name="prediction_label",
feature_column_names=["feature0"],
tag_column_names=["tag0"],
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
Expand All @@ -140,14 +140,14 @@ def test_all_features_and_tags_excluded_sets_schema_features_and_tags_fields_to_
"tag0": ["tag" for _ in range(self.num_records)],
}
)
excludes = ["feature0", "feature1", "tag0"]
excluded_column_names = ["feature0", "feature1", "tag0"]
input_schema = Schema(
prediction_id_column_name="prediction_id",
timestamp_column_name="timestamp",
feature_column_names=["feature0", "feature1"],
tag_column_names=["tag0"],
prediction_label_column_name="prediction_label",
excludes=excludes,
excluded_column_names=excluded_column_names,
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -160,7 +160,7 @@ def test_all_features_and_tags_excluded_sets_schema_features_and_tags_fields_to_
prediction_label_column_name="prediction_label",
feature_column_names=None,
tag_column_names=None,
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
Expand All @@ -181,7 +181,7 @@ def test_excluded_single_column_schema_fields_set_to_none(self, caplog):
timestamp_column_name="timestamp",
prediction_label_column_name="prediction_label",
feature_column_names=["feature0", "feature1"],
excludes=["prediction_label"],
excluded_column_names=["prediction_label"],
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -192,13 +192,13 @@ def test_excluded_single_column_schema_fields_set_to_none(self, caplog):
expected_parsed_schema=replace(
input_schema,
prediction_label_column_name=None,
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
)

def test_no_input_schema_features_and_no_excludes_discovers_features(self, caplog):
def test_no_input_schema_features_and_no_excluded_column_names_discovers_features(self, caplog):
input_dataframe = DataFrame(
{
"prediction_id": [str(x) for x in range(self.num_records)],
Expand All @@ -225,7 +225,7 @@ def test_no_input_schema_features_and_no_excludes_discovers_features(self, caplo
caplog=caplog,
)

def test_no_input_schema_features_and_list_of_excludes_discovers_non_excluded_features(
def test_no_input_schema_features_and_nonempty_excluded_column_names_discovers_features(
self, caplog
):
input_dataframe = DataFrame(
Expand All @@ -240,13 +240,13 @@ def test_no_input_schema_features_and_list_of_excludes_discovers_non_excluded_fe
"tag1": ["tag1" for _ in range(self.num_records)],
}
)
excludes = ["prediction_label", "feature1", "tag0"]
excluded_column_names = ["prediction_label", "feature1", "tag0"]
input_schema = Schema(
prediction_id_column_name="prediction_id",
timestamp_column_name="timestamp",
tag_column_names=["tag0", "tag1"],
prediction_label_column_name="prediction_label",
excludes=excludes,
excluded_column_names=excluded_column_names,
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -259,7 +259,7 @@ def test_no_input_schema_features_and_list_of_excludes_discovers_non_excluded_fe
prediction_label_column_name=None,
feature_column_names=["feature0", "feature2"],
tag_column_names=["tag1"],
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
Expand All @@ -278,14 +278,14 @@ def test_excluded_column_not_contained_in_dataframe_logs_warning(self, caplog):
"tag1": ["tag1" for _ in range(self.num_records)],
}
)
excludes = ["prediction_label", "column_not_in_dataframe"]
excluded_column_names = ["prediction_label", "column_not_in_dataframe"]
input_schema = Schema(
prediction_id_column_name="prediction_id",
timestamp_column_name="timestamp",
feature_column_names=["feature0", "feature1", "feature2"],
tag_column_names=["tag0", "tag1"],
prediction_label_column_name="prediction_label",
excludes=excludes,
excluded_column_names=excluded_column_names,
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -294,7 +294,7 @@ def test_excluded_column_not_contained_in_dataframe_logs_warning(self, caplog):
["prediction_id", "timestamp", "feature0", "feature1", "feature2", "tag0", "tag1"]
],
expected_parsed_schema=replace(
input_schema, prediction_label_column_name=None, excludes=None
input_schema, prediction_label_column_name=None, excluded_column_names=None
),
should_log_warning_to_user=True,
caplog=caplog,
Expand Down Expand Up @@ -362,7 +362,7 @@ def test_embedding_columns_of_excluded_embedding_feature_are_removed(self, caplo
raw_data_column_name="raw_data_column1",
),
},
excludes=["embedding_feature0"],
excluded_column_names=["embedding_feature0"],
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -385,7 +385,7 @@ def test_embedding_columns_of_excluded_embedding_feature_are_removed(self, caplo
raw_data_column_name="raw_data_column1",
)
},
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
Expand Down Expand Up @@ -413,7 +413,7 @@ def test_excluding_all_embedding_features_sets_schema_embedding_field_to_none(se
raw_data_column_name="raw_data_column0",
),
},
excludes=["embedding_feature0"],
excluded_column_names=["embedding_feature0"],
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -422,7 +422,7 @@ def test_excluding_all_embedding_features_sets_schema_embedding_field_to_none(se
expected_parsed_schema=replace(
input_schema,
embedding_feature_column_names=None,
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
Expand Down Expand Up @@ -452,15 +452,15 @@ def test_excluding_an_embedding_column_rather_than_the_embedding_feature_name_lo
raw_data_column_name="raw_data_column0",
),
},
excludes=["embedding_vector0"],
excluded_column_names=["embedding_vector0"],
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
input_schema=input_schema,
expected_parsed_dataframe=input_dataframe,
expected_parsed_schema=replace(
input_schema,
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=True,
caplog=caplog,
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_excluding_embedding_feature_with_same_name_as_embedding_column_does_not
raw_data_column_name="raw_data_column0",
),
},
excludes=["embedding0"],
excluded_column_names=["embedding0"],
)
self._parse_dataframe_and_schema_and_check_output(
input_dataframe=input_dataframe,
Expand All @@ -498,7 +498,7 @@ def test_excluding_embedding_feature_with_same_name_as_embedding_column_does_not
expected_parsed_schema=replace(
input_schema,
embedding_feature_column_names=None,
excludes=None,
excluded_column_names=None,
),
should_log_warning_to_user=False,
caplog=caplog,
Expand Down Expand Up @@ -702,7 +702,7 @@ def test_dataset_validate_invalid_schema_excludes_timestamp(self) -> None:
timestamp_column_name="timestamp",
feature_column_names=["feature0"],
prediction_label_column_name="prediction_label",
excludes=["timestamp"],
excluded_column_names=["timestamp"],
)

with raises(DatasetError):
Expand All @@ -722,7 +722,7 @@ def test_dataset_validate_invalid_schema_excludes_prediction_id(self) -> None:
prediction_id_column_name="prediction_id",
feature_column_names=["feature0"],
prediction_label_column_name="prediction_label",
excludes=["prediction_id"],
excluded_column_names=["prediction_id"],
)

with raises(DatasetError):
Expand Down