Skip to content

Commit

Permalink
feat(spark): added reference statistics for regression (#55)
Browse files Browse the repository at this point in the history
* feat: added reference statistics for regression

* fix: added serialize_as_any=True in tests

---------

Co-authored-by: lorenzodagostinoradicalbit <lorenzo.dagostino@radicalbit.ai>
  • Loading branch information
lorenzodagostinoradicalbit and lorenzodagostinoradicalbit authored Jul 1, 2024
1 parent c05b133 commit cdff426
Show file tree
Hide file tree
Showing 11 changed files with 834 additions and 775 deletions.
8 changes: 6 additions & 2 deletions spark/jobs/current_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,18 @@ def main(
complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality).decode(
"utf-8"
)
complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")
complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
complete_record["DRIFT"] = orjson.dumps(drift).decode("utf-8")
case ModelType.MULTI_CLASS:
statistics = calculate_statistics_current(current_dataset)
complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")
complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)

schema = StructType(
[
Expand Down
10 changes: 6 additions & 4 deletions spark/jobs/metrics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from models.reference_dataset import ReferenceDataset
import pyspark.sql.functions as F

from models.statistics import Statistics

N_VARIABLES = "n_variables"
N_OBSERVATION = "n_observations"
MISSING_CELLS = "missing_cells"
Expand All @@ -17,7 +19,7 @@
# FIXME generalize to one method
def calculate_statistics_reference(
reference_dataset: ReferenceDataset,
) -> dict[str, float]:
) -> Statistics:
number_of_variables = len(reference_dataset.get_all_variables())
number_of_observations = reference_dataset.reference_count
number_of_numerical = len(reference_dataset.get_numerical_variables())
Expand Down Expand Up @@ -79,12 +81,12 @@ def calculate_statistics_reference(
.to_dict(orient="records")[0]
)

return stats
return Statistics(**stats)


def calculate_statistics_current(
current_dataset: CurrentDataset,
) -> dict[str, float]:
) -> Statistics:
number_of_variables = len(current_dataset.get_all_variables())
number_of_observations = current_dataset.current_count
number_of_numerical = len(current_dataset.get_numerical_variables())
Expand Down Expand Up @@ -146,4 +148,4 @@ def calculate_statistics_current(
.to_dict(orient="records")[0]
)

return stats
return Statistics(**stats)
6 changes: 5 additions & 1 deletion spark/jobs/models/reference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from pyspark.ml.feature import StringIndexer
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType, StructField, StructType
from pyspark.sql.types import (
DoubleType,
StructField,
StructType,
)

from utils.models import ModelOut, ModelType, ColumnDefinition
from utils.spark import apply_schema_to_dataframe
Expand Down
17 changes: 17 additions & 0 deletions spark/jobs/models/statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import BaseModel, ConfigDict

from typing import Optional


class Statistics(BaseModel):
n_variables: int
n_observations: int
missing_cells: int
missing_cells_perc: Optional[float]
duplicate_rows: int
duplicate_rows_perc: Optional[float]
numeric: int
categorical: int
datetime: int

model_config = ConfigDict(ser_json_inf_nan="null")
13 changes: 11 additions & 2 deletions spark/jobs/reference_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def main(
complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality).decode(
"utf-8"
)
complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")
complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
Expand All @@ -73,7 +75,9 @@ def main(
statistics = calculate_statistics_reference(reference_dataset)
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality()
complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")
complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)
complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
serialize_as_any=True
)
Expand All @@ -84,7 +88,12 @@ def main(
metrics_service = ReferenceMetricsRegressionService(
reference=reference_dataset
)
statistics = calculate_statistics_reference(reference_dataset)
model_quality = metrics_service.calculate_model_quality()

complete_record["STATISTICS"] = statistics.model_dump_json(
serialize_as_any=True
)
complete_record["MODEL_QUALITY"] = model_quality.model_dump_json(
serialize_as_any=True
)
Expand Down
22 changes: 11 additions & 11 deletions spark/tests/binary_current_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_calculation(spark_fixture, dataset):
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_calculation_current_joined(spark_fixture, current_joined):
stats = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 1,
"datetime": 1,
Expand Down Expand Up @@ -824,7 +824,7 @@ def test_calculation_complete(spark_fixture, complete_dataset):
stats = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 0,
"missing_cells_perc": 0.0,
Expand Down Expand Up @@ -973,7 +973,7 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset):
stats = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 0,
"missing_cells_perc": 0.0,
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing):
stats = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 5,
"missing_cells_perc": 6.25,
Expand Down Expand Up @@ -1286,7 +1286,7 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime)
stats = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -1450,7 +1450,7 @@ def test_calculation_easy_dataset_bucket_test(spark_fixture, easy_dataset_bucket
stats = calculate_statistics_current(current_dataset)
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 0,
"missing_cells_perc": 0.0,
Expand Down Expand Up @@ -1628,7 +1628,7 @@ def test_calculation_for_hour(spark_fixture, dataset_for_hour):
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -1921,7 +1921,7 @@ def test_calculation_for_day(spark_fixture, dataset_for_day):
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -2200,7 +2200,7 @@ def test_calculation_for_week(spark_fixture, dataset_for_week):
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -2479,7 +2479,7 @@ def test_calculation_for_month(spark_fixture, dataset_for_month):
data_quality = metrics_service.calculate_data_quality()
model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down
16 changes: 8 additions & 8 deletions spark/tests/binary_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_calculation(spark_fixture, dataset):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_calculation_reference_joined(spark_fixture, reference_joined):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 1,
"datetime": 1,
Expand Down Expand Up @@ -696,7 +696,7 @@ def test_calculation_complete(spark_fixture, complete_dataset):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 0,
"missing_cells_perc": 0.0,
Expand Down Expand Up @@ -851,7 +851,7 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 0,
"missing_cells_perc": 0.0,
Expand Down Expand Up @@ -1005,7 +1005,7 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 5,
"missing_cells_perc": 6.25,
Expand Down Expand Up @@ -1182,7 +1182,7 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime)
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -1365,7 +1365,7 @@ def test_calculation_enhanced_data(spark_fixture, enhanced_data):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 2996,
"missing_cells_perc": 0.6241666666666668,
Expand Down Expand Up @@ -1905,7 +1905,7 @@ def test_calculation_dataset_bool_missing(spark_fixture, dataset_bool_missing):
model_quality = metrics_service.calculate_model_quality()
data_quality = metrics_service.calculate_data_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"missing_cells": 5,
"missing_cells_perc": 6.25,
Expand Down
6 changes: 3 additions & 3 deletions spark/tests/multiclass_current_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_calculation_dataset_target_int(spark_fixture, dataset_target_int):

stats = calculate_statistics_current(current_dataset)

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_calculation_dataset_target_string(spark_fixture, dataset_target_string)

stats = calculate_statistics_current(current_dataset)

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 4,
"datetime": 1,
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_clas

stats = calculate_statistics_current(current_dataset)

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 4,
"datetime": 1,
Expand Down
6 changes: 3 additions & 3 deletions spark/tests/multiclass_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_calculation_dataset_target_int(spark_fixture, dataset_target_int):
data_quality = multiclass_service.calculate_data_quality()
model_quality = multiclass_service.calculate_model_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 2,
"datetime": 1,
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_calculation_dataset_target_string(spark_fixture, dataset_target_string)
data_quality = multiclass_service.calculate_data_quality()
model_quality = multiclass_service.calculate_model_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 4,
"datetime": 1,
Expand Down Expand Up @@ -521,7 +521,7 @@ def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_clas
data_quality = multiclass_service.calculate_data_quality()
model_quality = multiclass_service.calculate_model_quality()

assert stats == my_approx(
assert stats.model_dump(serialize_as_any=True) == my_approx(
{
"categorical": 4,
"datetime": 1,
Expand Down
Loading

0 comments on commit cdff426

Please sign in to comment.