diff --git a/.github/workflows/trivy-scan.yaml b/.github/workflows/trivy-scan.yaml index 28e6e161..303c990b 100644 --- a/.github/workflows/trivy-scan.yaml +++ b/.github/workflows/trivy-scan.yaml @@ -54,7 +54,7 @@ jobs: scan-spark: needs: [changes] if: ${{ github.event_name == 'push' || ( github.event_name == 'pull_request' && contains(needs.changes.outputs.changed_files, 'spark/poetry.lock') ) }} - uses: radicalbit/radicalbit-github-workflows/.github/workflows/trivy-fs-scan.yaml@v1 + uses: radicalbit/radicalbit-github-workflows/.github/workflows/trivy-fs-scan.yaml@main with: directory: ./spark prcomment: ${{ github.event_name == 'pull_request' && contains(needs.changes.outputs.changed_files, 'spark/poetry.lock') }} diff --git a/.gitignore b/.gitignore index 903c57b3..4a5274e5 100644 --- a/.gitignore +++ b/.gitignore @@ -106,4 +106,5 @@ ui/*.sw? ## K3S SPECIFICS # ##################### -docker/k3s_data/kubeconfig/ \ No newline at end of file +docker/k3s_data/kubeconfig/ +docker/k3s_data/images/ \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index 06975ad4..6a717779 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -150,6 +150,8 @@ services: - ./docker/k3s_data/manifests/spark-init.yaml:/var/lib/rancher/k3s/server/manifests/spark-init.yaml # Mount entrypoint - ./docker/k3s_data/init/entrypoint.sh:/opt/entrypoint/entrypoint.sh + # Preload docker images + - ./docker/k3s_data/images:/var/lib/rancher/k3s/agent/images expose: - "6443" # Kubernetes API Server - "80" # Ingress controller port 80 diff --git a/spark/Dockerfile b/spark/Dockerfile index dd48e838..718dca82 100644 --- a/spark/Dockerfile +++ b/spark/Dockerfile @@ -1,5 +1,16 @@ +FROM python:3.10.14-slim AS build + +WORKDIR /build +COPY poetry.lock pyproject.toml ./ +RUN pip install --no-cache-dir poetry==1.8.3 && \ + poetry export -f requirements.txt -o requirements.txt + + FROM spark:3.5.1-scala2.12-java17-python3-ubuntu +# Requirements from previous step +COPY --from=build /build/requirements.txt . + # Adding needed jar RUN curl -o /opt/spark/jars/bcprov-jdk15on-1.70.jar https://repo1.maven.org/maven2/org/bouncycastle/bcprov-jdk15on/1.70/bcprov-jdk15on-1.70.jar && \ curl -o /opt/spark/jars/bcpkix-jdk15on-1.70.jar https://repo1.maven.org/maven2/org/bouncycastle/bcpkix-jdk15on/1.70/bcpkix-jdk15on-1.70.jar && \ @@ -9,8 +20,11 @@ RUN curl -o /opt/spark/jars/bcprov-jdk15on-1.70.jar https://repo1.maven.org/mave USER root -# Adding needed python libs that will be used by pyspark jobs -RUN pip install numpy pydantic pandas psycopg2-binary orjson scipy +RUN apt-get update && \ + apt-get install -y --no-install-recommends gcc libpq-dev python3-dev + +# Install requirements coming from pyproject +RUN pip install --no-cache-dir -r requirements.txt USER spark diff --git a/spark/README.md b/spark/README.md index 53a30ec9..b7725326 100644 --- a/spark/README.md +++ b/spark/README.md @@ -2,15 +2,35 @@ This folder contains files to create the Spark docker image that will be used to calculate metrics. -The custom image is created using the `Dockerfile` and it is a base Spark image where are installed some additional dependencies and loaded with the custom jobs located in the `jobs` folder. - -To create an additional job, add a `.py` file in `jobs` folder (take as an example `reference_job.py` for the boilerplate) +The custom image is created using the `Dockerfile`: it is a base Spark image with additional dependencies installed and the custom jobs from the `jobs` folder loaded into it.
### Development This is a poetry project that can be used to develop and test the jobs before putting them in the docker image. -NB: if additional python dependencies are needed, pleas add them in `Dockerfile` accordingly, and not only in the `pyproject.toml` +To create an additional job, add a `.py` file in the `jobs` folder (use `reference_job.py` as a boilerplate) and write unit tests for it. + +### End-to-end testing + +Before publishing the image, it is possible to test the platform against new developments or improvements made to the Spark image. + +From this project folder, run: + +```bash +docker build . -t radicalbit-spark-py:develop && docker save radicalbit-spark-py:develop -o ../docker/k3s_data/images/radicalbit-spark-py:develop.tar +``` + +This builds the new image and saves it in `docker/k3s_data/images/`. + +To use this image in the Radicalbit Platform, modify the docker compose file by adding the following environment variable to the `api` container: + +``` +SPARK_IMAGE: "radicalbit-spark-py:develop" +``` + +When the k3s cluster inside the docker compose starts, it automatically loads the saved image, which can then be used to test the code under development. + +NB: whenever a new image is built and saved, the k3s container must be restarted. #### Formatting and linting diff --git a/spark/jobs/current_job.py b/spark/jobs/current_job.py index 375edf29..9cc46507 100644 --- a/spark/jobs/current_job.py +++ b/spark/jobs/current_job.py @@ -6,9 +6,11 @@ import orjson from pyspark.sql.types import StructType, StructField, StringType +from jobs.metrics.statistics import calculate_statistics_current +from jobs.models.current_dataset import CurrentDataset +from jobs.models.reference_dataset import ReferenceDataset from utils.current import CurrentMetricsService from utils.models import JobStatus, ModelOut -from utils.spark import apply_schema_to_dataframe from utils.db import update_job_status, write_to_db from pyspark.sql import SparkSession @@ -42,22 +44,15 @@ def main( "fs.s3a.connection.ssl.enabled", "false" ) - current_schema = model.to_current_spark_schema() - current_dataset = spark_session.read.csv(current_dataset_path, header=True) - current_dataset = apply_schema_to_dataframe(current_dataset, current_schema) - current_dataset = current_dataset.select( - *[c for c in current_schema.names if c in current_dataset.columns] - ) - reference_schema = model.to_reference_spark_schema() - reference_dataset = spark_session.read.csv(reference_dataset_path, header=True) - reference_dataset = apply_schema_to_dataframe(reference_dataset, reference_schema) - reference_dataset = reference_dataset.select( - *[c for c in reference_schema.names if c in reference_dataset.columns] - ) + raw_current = spark_session.read.csv(current_dataset_path, header=True) + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current) + raw_reference = spark_session.read.csv(reference_dataset_path, header=True) + reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_reference) + metrics_service = CurrentMetricsService( - spark_session, current_dataset, reference_dataset, model=model + spark_session, current_dataset.current, reference_dataset.reference, model=model ) - statistics = metrics_service.calculate_statistics() + statistics = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp() drift = metrics_service.calculate_drift() diff --git
a/spark/jobs/metrics/__init__.py b/spark/jobs/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spark/jobs/metrics/statistics.py b/spark/jobs/metrics/statistics.py new file mode 100644 index 00000000..ad3cae09 --- /dev/null +++ b/spark/jobs/metrics/statistics.py @@ -0,0 +1,148 @@ +from models.current_dataset import CurrentDataset +from models.reference_dataset import ReferenceDataset +import pyspark.sql.functions as F + +N_VARIABLES = "n_variables" +N_OBSERVATION = "n_observations" +MISSING_CELLS = "missing_cells" +MISSING_CELLS_PERC = "missing_cells_perc" +DUPLICATE_ROWS = "duplicate_rows" +DUPLICATE_ROWS_PERC = "duplicate_rows_perc" +NUMERIC = "numeric" +CATEGORICAL = "categorical" +DATETIME = "datetime" + + +# FIXME use pydantic struct like data quality +def calculate_statistics_reference( + reference_dataset: ReferenceDataset, +) -> dict[str, float]: + number_of_variables = len(reference_dataset.get_all_variables()) + number_of_observations = reference_dataset.reference_count + number_of_numerical = len(reference_dataset.get_numerical_variables()) + number_of_categorical = len(reference_dataset.get_categorical_variables()) + number_of_datetime = len(reference_dataset.get_datetime_variables()) + reference_columns = reference_dataset.reference.columns + + stats = ( + reference_dataset.reference.select( + [ + F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c) + if t not in ("datetime", "date", "timestamp", "bool", "boolean") + else F.count(F.when(F.col(c).isNull(), c)).alias(c) + for c, t in reference_dataset.reference.dtypes + ] + ) + .withColumn(MISSING_CELLS, sum([F.col(c) for c in reference_columns])) + .withColumn( + MISSING_CELLS_PERC, + (F.col(MISSING_CELLS) / (number_of_variables * number_of_observations)) + * 100, + ) + .withColumn( + DUPLICATE_ROWS, + F.lit( + number_of_observations + - reference_dataset.reference.dropDuplicates( + [ + c + for c in reference_columns + if c != reference_dataset.model.timestamp.name + ] + ).count() + ), + ) + .withColumn( + DUPLICATE_ROWS_PERC, + (F.col(DUPLICATE_ROWS) / number_of_observations) * 100, + ) + .withColumn(N_VARIABLES, F.lit(number_of_variables)) + .withColumn(N_OBSERVATION, F.lit(number_of_observations)) + .withColumn(NUMERIC, F.lit(number_of_numerical)) + .withColumn(CATEGORICAL, F.lit(number_of_categorical)) + .withColumn(DATETIME, F.lit(number_of_datetime)) + .select( + *[ + MISSING_CELLS, + MISSING_CELLS_PERC, + DUPLICATE_ROWS, + DUPLICATE_ROWS_PERC, + N_VARIABLES, + N_OBSERVATION, + NUMERIC, + CATEGORICAL, + DATETIME, + ] + ) + .toPandas() + .to_dict(orient="records")[0] + ) + + return stats + + +def calculate_statistics_current( + current_dataset: CurrentDataset, +) -> dict[str, float]: + number_of_variables = len(current_dataset.get_all_variables()) + number_of_observations = current_dataset.current_count + number_of_numerical = len(current_dataset.get_numerical_variables()) + number_of_categorical = len(current_dataset.get_categorical_variables()) + number_of_datetime = len(current_dataset.get_datetime_variables()) + reference_columns = current_dataset.current.columns + + stats = ( + current_dataset.current.select( + [ + F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c) + if t not in ("datetime", "date", "timestamp", "bool", "boolean") + else F.count(F.when(F.col(c).isNull(), c)).alias(c) + for c, t in current_dataset.current.dtypes + ] + ) + .withColumn(MISSING_CELLS, sum([F.col(c) for c in reference_columns])) + .withColumn( + MISSING_CELLS_PERC, + (F.col(MISSING_CELLS) / 
(number_of_variables * number_of_observations)) + * 100, + ) + .withColumn( + DUPLICATE_ROWS, + F.lit( + number_of_observations + - current_dataset.current.dropDuplicates( + [ + c + for c in reference_columns + if c != current_dataset.model.timestamp.name + ] + ).count() + ), + ) + .withColumn( + DUPLICATE_ROWS_PERC, + (F.col(DUPLICATE_ROWS) / number_of_observations) * 100, + ) + .withColumn(N_VARIABLES, F.lit(number_of_variables)) + .withColumn(N_OBSERVATION, F.lit(number_of_observations)) + .withColumn(NUMERIC, F.lit(number_of_numerical)) + .withColumn(CATEGORICAL, F.lit(number_of_categorical)) + .withColumn(DATETIME, F.lit(number_of_datetime)) + .select( + *[ + MISSING_CELLS, + MISSING_CELLS_PERC, + DUPLICATE_ROWS, + DUPLICATE_ROWS_PERC, + N_VARIABLES, + N_OBSERVATION, + NUMERIC, + CATEGORICAL, + DATETIME, + ] + ) + .toPandas() + .to_dict(orient="records")[0] + ) + + return stats diff --git a/spark/jobs/models/__init__.py b/spark/jobs/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spark/jobs/models/current_dataset.py b/spark/jobs/models/current_dataset.py new file mode 100644 index 00000000..07154bf1 --- /dev/null +++ b/spark/jobs/models/current_dataset.py @@ -0,0 +1,97 @@ +from typing import List + +from pyspark.sql import DataFrame +from pyspark.sql.types import DoubleType, StructField, StructType + +from utils.models import ModelOut, ModelType, ColumnDefinition +from utils.spark import apply_schema_to_dataframe + + +class CurrentDataset: + def __init__(self, model: ModelOut, raw_dataframe: DataFrame): + current_schema = self.spark_schema(model) + current_dataset = apply_schema_to_dataframe(raw_dataframe, current_schema) + + self.model = model + self.current = current_dataset.select( + *[c for c in current_schema.names if c in current_dataset.columns] + ) + self.current_count = self.current.count() + + # FIXME this must exclude target when we will have separate current and ground truth + @staticmethod + def spark_schema(model: ModelOut): + all_features = ( + model.features + [model.target] + [model.timestamp] + model.outputs.output + ) + if model.outputs.prediction_proba and model.model_type == ModelType.BINARY: + enforce_float = [ + model.target.name, + model.outputs.prediction.name, + model.outputs.prediction_proba.name, + ] + elif model.model_type == ModelType.BINARY: + enforce_float = [model.target.name, model.outputs.prediction.name] + else: + enforce_float = [] + return StructType( + [ + StructField( + name=feature.name, + dataType=model.convert_types(feature.type), + nullable=False, + ) + if feature.name not in enforce_float + else StructField( + name=feature.name, + dataType=DoubleType(), + nullable=False, + ) + for feature in all_features + ] + ) + + def get_numerical_features(self) -> List[ColumnDefinition]: + return [feature for feature in self.model.features if feature.is_numerical()] + + def get_categorical_features(self) -> List[ColumnDefinition]: + return [feature for feature in self.model.features if feature.is_categorical()] + + # FIXME this must exclude target when we will have separate current and ground truth + def get_numerical_variables(self) -> List[ColumnDefinition]: + all_features = ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) + return [feature for feature in all_features if feature.is_numerical()] + + # FIXME this must exclude target when we will have separate current and ground truth + def get_categorical_variables(self) -> List[ColumnDefinition]: + all_features = 
( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) + return [feature for feature in all_features if feature.is_categorical()] + + # FIXME this must exclude target when we will have separate current and ground truth + def get_datetime_variables(self) -> List[ColumnDefinition]: + all_features = ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) + return [feature for feature in all_features if feature.is_datetime()] + + # FIXME this must exclude target when we will have separate current and ground truth + def get_all_variables(self) -> List[ColumnDefinition]: + return ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) diff --git a/spark/jobs/models/reference_dataset.py b/spark/jobs/models/reference_dataset.py new file mode 100644 index 00000000..144f2b6e --- /dev/null +++ b/spark/jobs/models/reference_dataset.py @@ -0,0 +1,92 @@ +from typing import List + +from pyspark.sql import DataFrame +from pyspark.sql.types import DoubleType, StructField, StructType + +from utils.models import ModelOut, ModelType, ColumnDefinition +from utils.spark import apply_schema_to_dataframe + + +class ReferenceDataset: + def __init__(self, model: ModelOut, raw_dataframe: DataFrame): + reference_schema = self.spark_schema(model) + reference_dataset = apply_schema_to_dataframe(raw_dataframe, reference_schema) + + self.model = model + self.reference = reference_dataset.select( + *[c for c in reference_schema.names if c in reference_dataset.columns] + ) + self.reference_count = self.reference.count() + + @staticmethod + def spark_schema(model: ModelOut): + all_features = ( + model.features + [model.target] + [model.timestamp] + model.outputs.output + ) + if model.outputs.prediction_proba and model.model_type == ModelType.BINARY: + enforce_float = [ + model.target.name, + model.outputs.prediction.name, + model.outputs.prediction_proba.name, + ] + elif model.model_type == ModelType.BINARY: + enforce_float = [model.target.name, model.outputs.prediction.name] + else: + enforce_float = [] + return StructType( + [ + StructField( + name=feature.name, + dataType=model.convert_types(feature.type), + nullable=False, + ) + if feature.name not in enforce_float + else StructField( + name=feature.name, + dataType=DoubleType(), + nullable=False, + ) + for feature in all_features + ] + ) + + def get_numerical_features(self) -> List[ColumnDefinition]: + return [feature for feature in self.model.features if feature.is_numerical()] + + def get_categorical_features(self) -> List[ColumnDefinition]: + return [feature for feature in self.model.features if feature.is_categorical()] + + def get_numerical_variables(self) -> List[ColumnDefinition]: + all_features = ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) + return [feature for feature in all_features if feature.is_numerical()] + + def get_categorical_variables(self) -> List[ColumnDefinition]: + all_features = ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) + return [feature for feature in all_features if feature.is_categorical()] + + def get_datetime_variables(self) -> List[ColumnDefinition]: + all_features = ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) + return [feature for feature in all_features if feature.is_datetime()] + + def 
get_all_variables(self) -> List[ColumnDefinition]: + return ( + self.model.features + + [self.model.target] + + [self.model.timestamp] + + self.model.outputs.output + ) diff --git a/spark/jobs/reference_job.py b/spark/jobs/reference_job.py index bb611607..52a3e963 100644 --- a/spark/jobs/reference_job.py +++ b/spark/jobs/reference_job.py @@ -5,9 +5,10 @@ import orjson from pyspark.sql.types import StructField, StructType, StringType +from metrics.statistics import calculate_statistics_reference +from models.reference_dataset import ReferenceDataset from utils.reference import ReferenceMetricsService -from utils.models import JobStatus, ModelOut -from utils.spark import apply_schema_to_dataframe +from utils.models import JobStatus, ModelOut, ModelType from utils.db import update_job_status, write_to_db from pyspark.sql import SparkSession @@ -42,25 +43,29 @@ def main( "fs.s3a.connection.ssl.enabled", "false" ) - reference_schema = model.to_reference_spark_schema() - reference_dataset = spark_session.read.csv(reference_dataset_path, header=True) - reference_dataset = apply_schema_to_dataframe(reference_dataset, reference_schema) - reference_dataset = reference_dataset.select( - *[c for c in reference_schema.names if c in reference_dataset.columns] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - model_quality = metrics_service.calculate_model_quality() - statistics = metrics_service.calculate_statistics() - data_quality = metrics_service.calculate_data_quality() - - # TODO put needed fields here - complete_record = { - "UUID": str(uuid.uuid4()), - "REFERENCE_UUID": reference_uuid, - "MODEL_QUALITY": orjson.dumps(model_quality).decode("utf-8"), - "STATISTICS": orjson.dumps(statistics).decode("utf-8"), - "DATA_QUALITY": data_quality.model_dump_json(serialize_as_any=True), - } + raw_dataframe = spark_session.read.csv(reference_dataset_path, header=True) + reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_dataframe) + + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) + + complete_record = {"UUID": str(uuid.uuid4()), "REFERENCE_UUID": reference_uuid} + + match model.model_type: + case ModelType.BINARY: + model_quality = metrics_service.calculate_model_quality() + statistics = calculate_statistics_reference(reference_dataset) + data_quality = metrics_service.calculate_data_quality() + complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality).decode( + "utf-8" + ) + complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8") + complete_record["DATA_QUALITY"] = data_quality.model_dump_json( + serialize_as_any=True + ) + case ModelType.MULTI_CLASS: + # TODO add data quality and model quality + statistics = calculate_statistics_reference(reference_dataset) + complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8") schema = StructType( [ diff --git a/spark/jobs/utils/current.py b/spark/jobs/utils/current.py index 41f1058c..40b4154e 100644 --- a/spark/jobs/utils/current.py +++ b/spark/jobs/utils/current.py @@ -8,7 +8,7 @@ ) from pyspark.ml.feature import Bucketizer from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.functions import count, when, isnan, col +from pyspark.sql.functions import col import pyspark.sql.functions as f from pyspark.sql.types import IntegerType @@ -71,72 +71,6 @@ def __init__( self.reference_count = self.reference.count() self.model = model - # FIXME use pydantic struct like data quality - def calculate_statistics(self) -> dict[str, float]: - 
number_of_variables = len(self.model.get_all_variables_current()) - number_of_observations = self.current_count - number_of_numerical = len(self.model.get_numerical_variables_current()) - number_of_categorical = len(self.model.get_categorical_variables_current()) - number_of_datetime = len(self.model.get_datetime_variables_current()) - current_columns = self.current.columns - - stats = ( - self.current.select( - [ - count(when(col(c).isNull() | isnan(c), c)).alias(c) - if t not in ("datetime", "date", "timestamp", "bool", "boolean") - else f.count(f.when(col(c).isNull(), c)).alias(c) - for c, t in self.current.dtypes - ] - ) - .withColumn(self.MISSING_CELLS, sum([f.col(c) for c in current_columns])) - .withColumn( - self.MISSING_CELLS_PERC, - ( - f.col(self.MISSING_CELLS) - / (number_of_variables * number_of_observations) - ) - * 100, - ) - .withColumn( - self.DUPLICATE_ROWS, - f.lit( - number_of_observations - - self.current.drop(self.model.timestamp.name) - .dropDuplicates( - [c for c in current_columns if c != self.model.timestamp.name] - ) - .count() - ), - ) - .withColumn( - self.DUPLICATE_ROWS_PERC, - (f.col(self.DUPLICATE_ROWS) / number_of_observations) * 100, - ) - .withColumn(self.N_VARIABLES, f.lit(number_of_variables)) - .withColumn(self.N_OBSERVATION, f.lit(number_of_observations)) - .withColumn(self.NUMERIC, f.lit(number_of_numerical)) - .withColumn(self.CATEGORICAL, f.lit(number_of_categorical)) - .withColumn(self.DATETIME, f.lit(number_of_datetime)) - .select( - *[ - self.MISSING_CELLS, - self.MISSING_CELLS_PERC, - self.DUPLICATE_ROWS, - self.DUPLICATE_ROWS_PERC, - self.N_VARIABLES, - self.N_OBSERVATION, - self.NUMERIC, - self.CATEGORICAL, - self.DATETIME, - ] - ) - .toPandas() - .to_dict(orient="records")[0] - ) - - return stats - def calculate_data_quality_numerical(self) -> List[NumericalFeatureMetrics]: numerical_features = [ numerical.name for numerical in self.model.get_numerical_features() diff --git a/spark/jobs/utils/models.py b/spark/jobs/utils/models.py index 32267319..bff567d3 100644 --- a/spark/jobs/utils/models.py +++ b/spark/jobs/utils/models.py @@ -4,8 +4,6 @@ from pydantic import BaseModel from pyspark.sql.types import ( - StructType, - StructField, StringType, DoubleType, IntegerType, @@ -97,111 +95,6 @@ def convert_types(t: str): case SupportedTypes.datetime: return TimestampType() - def to_reference_spark_schema(self): - """ - This will enforce float for target, prediction and prediction_proba - :return: the spark scheme of the reference dataset - """ - - all_features = ( - self.features + [self.target] + [self.timestamp] + self.outputs.output - ) - if self.outputs.prediction_proba: - enforce_float = [ - self.target.name, - self.outputs.prediction.name, - self.outputs.prediction_proba.name, - ] - else: - enforce_float = [self.target.name, self.outputs.prediction.name] - return StructType( - [ - StructField( - name=feature.name, - dataType=self.convert_types(feature.type), - nullable=False, - ) - if feature.name not in enforce_float - else StructField( - name=feature.name, - dataType=DoubleType(), - nullable=False, - ) - for feature in all_features - ] - ) - - # FIXME this must exclude target when we will have separate current and ground truth - def to_current_spark_schema(self): - """ - This will enforce float for target, prediction and prediction_proba - :return: the spark scheme of the current dataset (in the future without target) - """ - - all_features = ( - self.features + [self.target] + [self.timestamp] + self.outputs.output - ) - if 
self.outputs.prediction_proba: - enforce_float = [ - self.target.name, - self.outputs.prediction.name, - self.outputs.prediction_proba.name, - ] - else: - enforce_float = [self.target.name, self.outputs.prediction.name] - return StructType( - [ - StructField( - name=feature.name, - dataType=self.convert_types(feature.type), - nullable=False, - ) - if feature.name not in enforce_float - else StructField( - name=feature.name, - dataType=DoubleType(), - nullable=False, - ) - for feature in all_features - ] - ) - - def get_numerical_variables_reference(self) -> List[ColumnDefinition]: - all_features = ( - self.features + [self.target] + [self.timestamp] + self.outputs.output - ) - return [feature for feature in all_features if feature.is_numerical()] - - def get_categorical_variables_reference(self) -> List[ColumnDefinition]: - all_features = ( - self.features + [self.target] + [self.timestamp] + self.outputs.output - ) - return [feature for feature in all_features if feature.is_categorical()] - - def get_datetime_variables_reference(self) -> List[ColumnDefinition]: - all_features = ( - self.features + [self.target] + [self.timestamp] + self.outputs.output - ) - return [feature for feature in all_features if feature.is_datetime()] - - def get_all_variables_reference(self) -> List[ColumnDefinition]: - return self.features + [self.target] + [self.timestamp] + self.outputs.output - - def get_numerical_variables_current(self) -> List[ColumnDefinition]: - all_features = self.features + [self.timestamp] + self.outputs.output - return [feature for feature in all_features if feature.is_numerical()] - - def get_categorical_variables_current(self) -> List[ColumnDefinition]: - all_features = self.features + [self.timestamp] + self.outputs.output - return [feature for feature in all_features if feature.is_categorical()] - - def get_datetime_variables_current(self) -> List[ColumnDefinition]: - all_features = self.features + [self.timestamp] + self.outputs.output - return [feature for feature in all_features if feature.is_datetime()] - - def get_all_variables_current(self) -> List[ColumnDefinition]: - return self.features + [self.timestamp] + self.outputs.output - def get_numerical_features(self) -> List[ColumnDefinition]: return [feature for feature in self.features if feature.is_numerical()] diff --git a/spark/jobs/utils/reference.py b/spark/jobs/utils/reference.py index 5015d40e..af6c4bb9 100644 --- a/spark/jobs/utils/reference.py +++ b/spark/jobs/utils/reference.py @@ -5,7 +5,7 @@ BinaryClassificationEvaluator, MulticlassClassificationEvaluator, ) -from pyspark.sql.functions import count, when, isnan, col +from pyspark.sql.functions import col import pyspark.sql.functions as f from .data_quality import ( @@ -19,17 +19,6 @@ class ReferenceMetricsService: - # Statistics - N_VARIABLES = "n_variables" - N_OBSERVATION = "n_observations" - MISSING_CELLS = "missing_cells" - MISSING_CELLS_PERC = "missing_cells_perc" - DUPLICATE_ROWS = "duplicate_rows" - DUPLICATE_ROWS_PERC = "duplicate_rows_perc" - NUMERIC = "numeric" - CATEGORICAL = "categorical" - DATETIME = "datetime" - # Model Quality model_quality_binary_classificator = { "areaUnderROC": "area_under_roc", @@ -97,70 +86,6 @@ def __calc_mc_metrics(self) -> dict[str, float]: for (name, label) in self.model_quality_multiclass_classificator.items() } - # FIXME use pydantic struct like data quality - def calculate_statistics(self) -> dict[str, float]: - number_of_variables = len(self.model.get_all_variables_reference()) - number_of_observations = 
self.reference_count - number_of_numerical = len(self.model.get_numerical_variables_reference()) - number_of_categorical = len(self.model.get_categorical_variables_reference()) - number_of_datetime = len(self.model.get_datetime_variables_reference()) - reference_columns = self.reference.columns - - stats = ( - self.reference.select( - [ - count(when(isnan(c) | col(c).isNull(), c)).alias(c) - if t not in ("datetime", "date", "timestamp", "bool", "boolean") - else f.count(f.when(col(c).isNull(), c)).alias(c) - for c, t in self.reference.dtypes - ] - ) - .withColumn(self.MISSING_CELLS, sum([f.col(c) for c in reference_columns])) - .withColumn( - self.MISSING_CELLS_PERC, - ( - f.col(self.MISSING_CELLS) - / (number_of_variables * number_of_observations) - ) - * 100, - ) - .withColumn( - self.DUPLICATE_ROWS, - f.lit( - number_of_observations - - self.reference.dropDuplicates( - [c for c in reference_columns if c != self.model.timestamp.name] - ).count() - ), - ) - .withColumn( - self.DUPLICATE_ROWS_PERC, - (f.col(self.DUPLICATE_ROWS) / number_of_observations) * 100, - ) - .withColumn(self.N_VARIABLES, f.lit(number_of_variables)) - .withColumn(self.N_OBSERVATION, f.lit(number_of_observations)) - .withColumn(self.NUMERIC, f.lit(number_of_numerical)) - .withColumn(self.CATEGORICAL, f.lit(number_of_categorical)) - .withColumn(self.DATETIME, f.lit(number_of_datetime)) - .select( - *[ - self.MISSING_CELLS, - self.MISSING_CELLS_PERC, - self.DUPLICATE_ROWS, - self.DUPLICATE_ROWS_PERC, - self.N_VARIABLES, - self.N_OBSERVATION, - self.NUMERIC, - self.CATEGORICAL, - self.DATETIME, - ] - ) - .toPandas() - .to_dict(orient="records")[0] - ) - - return stats - # FIXME use pydantic struct like data quality def calculate_model_quality(self) -> dict[str, float]: metrics = self.__calc_mc_metrics() diff --git a/spark/poetry.lock b/spark/poetry.lock index 9e86d23d..2e72f38f 100644 --- a/spark/poetry.lock +++ b/spark/poetry.lock @@ -24,8 +24,10 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "os_name == \"nt\""} +importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} packaging = ">=19.1" pyproject_hooks = "*" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} [package.extras] docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] @@ -436,15 +438,29 @@ https = ["urllib3 (>=1.24.1)"] paramiko = ["paramiko"] pgp = ["gpg"] +[[package]] +name = "exceptiongroup" +version = "1.2.1" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, + {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "fastjsonschema" -version = "2.19.1" +version = "2.20.0" description = "Fastest Python implementation of JSON schema" optional = false python-versions = "*" files = [ - {file = "fastjsonschema-2.19.1-py3-none-any.whl", hash = "sha256:3672b47bc94178c9f23dbb654bf47440155d4db9df5f7bc47643315f9c405cd0"}, - {file = "fastjsonschema-2.19.1.tar.gz", hash = "sha256:e3126a94bdc4623d3de4485f8d468a12f02a67921315ddc87836d6e456dc789d"}, + {file = "fastjsonschema-2.20.0-py3-none-any.whl", hash = 
"sha256:5875f0b0fa7a0043a91e93a9b8f793bcbbba9691e7fd83dca95c28ba26d21f0a"}, + {file = "fastjsonschema-2.20.0.tar.gz", hash = "sha256:3d48fc5300ee96f5d116f10fe6f28d938e6008f59a6a025c2649475b87f76a23"}, ] [package.extras] @@ -452,18 +468,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.4" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -479,22 +495,22 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.1.0" +version = "8.0.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, - {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, + {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, + {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "iniconfig" @@ -576,13 +592,13 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-ena [[package]] name = "more-itertools" -version = "10.2.0" +version = "10.3.0" description = "More routines for operating on iterables, beyond itertools" optional = false python-versions = ">=3.8" files = [ - {file = "more-itertools-10.2.0.tar.gz", hash = 
"sha256:8fccb480c43d3e99a00087634c06dd02b0d50fbf088b380de5a41a015ec239e1"}, - {file = "more_itertools-10.2.0-py3-none-any.whl", hash = "sha256:686b06abe565edfab151cb8fd385a05651e1fdf8f0a14191e4439283421f8684"}, + {file = "more-itertools-10.3.0.tar.gz", hash = "sha256:e5d93ef411224fbcef366a6e8ddc4c5781bc6359d43412a65dd5964e46111463"}, + {file = "more_itertools-10.3.0-py3-none-any.whl", hash = "sha256:ea6a02e24a9161e51faad17a8782b92a0df82c12c1c8886fec7f0c3fa1a1b320"}, ] [[package]] @@ -766,13 +782,13 @@ files = [ [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -783,7 +799,6 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -797,14 +812,12 @@ files = [ {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - 
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -815,6 +828,7 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -863,13 +877,13 @@ ptyprocess = ">=0.5" [[package]] name = "pkginfo" -version = "1.11.0" +version = "1.11.1" description = "Query metadata from sdists / bdists / installed packages." optional = false python-versions = ">=3.8" files = [ - {file = "pkginfo-1.11.0-py3-none-any.whl", hash = "sha256:6d4998d1cd42c297af72cc0eab5f5bab1d356fb8a55b828fa914173f8bc1ba05"}, - {file = "pkginfo-1.11.0.tar.gz", hash = "sha256:dba885aa82e31e80d615119874384923f4e011c2a39b0c4b7104359e36cb7087"}, + {file = "pkginfo-1.11.1-py3-none-any.whl", hash = "sha256:bfa76a714fdfc18a045fcd684dbfc3816b603d9d075febef17cb6582bea29573"}, + {file = "pkginfo-1.11.1.tar.gz", hash = "sha256:2e0dca1cf4c8e39644eed32408ea9966ee15e0d324c62ba899a393b3c6b467aa"}, ] [package.extras] @@ -936,6 +950,7 @@ pyproject-hooks = ">=1.0.0,<2.0.0" requests = ">=2.26,<3.0" requests-toolbelt = ">=1.0.0,<2.0.0" shellingham = ">=1.5,<2.0" +tomli = {version = ">=2.0.1,<3.0.0", markers = "python_version < \"3.11\""} tomlkit = ">=0.11.4,<1.0.0" trove-classifiers = ">=2022.5.19" virtualenv = ">=20.23.0,<21.0.0" @@ -978,8 +993,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -1024,13 +1037,13 @@ files = [ [[package]] name = "pydantic" -version = "2.7.3" +version = "2.7.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"}, - {file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"}, + {file = "pydantic-2.7.4-py3-none-any.whl", hash = 
"sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"}, + {file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"}, ] [package.dependencies] @@ -1176,9 +1189,11 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=1.5,<2.0" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -1361,71 +1376,71 @@ requests = ">=2.0.1,<3.0.0" [[package]] name = "ruff" -version = "0.4.7" +version = "0.4.10" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e089371c67892a73b6bb1525608e89a2aca1b77b5440acf7a71dda5dac958f9e"}, - {file = "ruff-0.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:10f973d521d910e5f9c72ab27e409e839089f955be8a4c8826601a6323a89753"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c3d110970001dfa494bcd95478e62286c751126dfb15c3c46e7915fc49694f"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa9773c6c00f4958f73b317bc0fd125295110c3776089f6ef318f4b775f0abe4"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07fc80bbb61e42b3b23b10fda6a2a0f5a067f810180a3760c5ef1b456c21b9db"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fa4dafe3fe66d90e2e2b63fa1591dd6e3f090ca2128daa0be33db894e6c18648"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7c0083febdec17571455903b184a10026603a1de078428ba155e7ce9358c5f6"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad1b20e66a44057c326168437d680a2166c177c939346b19c0d6b08a62a37589"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf5d818553add7511c38b05532d94a407f499d1a76ebb0cad0374e32bc67202"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:50e9651578b629baec3d1513b2534de0ac7ed7753e1382272b8d609997e27e83"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8874a9df7766cb956b218a0a239e0a5d23d9e843e4da1e113ae1d27ee420877a"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9de9a6e49f7d529decd09381c0860c3f82fa0b0ea00ea78409b785d2308a567"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:13a1768b0691619822ae6d446132dbdfd568b700ecd3652b20d4e8bc1e498f78"}, - {file = "ruff-0.4.7-py3-none-win32.whl", hash = "sha256:769e5a51df61e07e887b81e6f039e7ed3573316ab7dd9f635c5afaa310e4030e"}, - {file = "ruff-0.4.7-py3-none-win_amd64.whl", hash = "sha256:9e3ab684ad403a9ed1226894c32c3ab9c2e0718440f6f50c7c5829932bc9e054"}, - {file = "ruff-0.4.7-py3-none-win_arm64.whl", hash = "sha256:10f2204b9a613988e3484194c2c9e96a22079206b22b787605c255f130db5ed7"}, - {file = "ruff-0.4.7.tar.gz", hash = "sha256:2331d2b051dc77a289a653fcc6a42cce357087c5975738157cd966590b18b5e1"}, + {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = 
"sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, + {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, + {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, + {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, + {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, + {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, ] [[package]] name = "scipy" -version = "1.13.1" +version = "1.14.0" description = "Fundamental algorithms for scientific computing in Python" optional = false -python-versions = ">=3.9" -files = [ - {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, - {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, - {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, - {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, - {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, - {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, - {file = 
"scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, - {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, - {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, - {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, - {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, - {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, - {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, - {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, - {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, - {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, - {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, - {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, - {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, - {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, - {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, - {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, - {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, - {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, - {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"}, + {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"}, + {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"}, + {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"}, + {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"}, + {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"}, + {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"}, + {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"}, + {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"}, + {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"}, + {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"}, + {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"}, + {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"}, + {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"}, + {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"}, + {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"}, + {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"}, ] [package.dependencies] -numpy = ">=1.22.4,<2.3" +numpy = ">=1.23.5,<2.3" [package.extras] -dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] -test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", 
"mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "secretstorage" @@ -1464,6 +1479,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + [[package]] name = "tomlkit" version = "0.12.5" @@ -1488,13 +1514,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.12.1" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, - {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -1510,13 +1536,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] @@ -1527,13 +1553,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.26.2" +version = "20.26.3" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.2-py3-none-any.whl", hash = "sha256:a624db5e94f01ad993d476b9ee5346fdf7b9de43ccaee0e0197012dc838a0e9b"}, - {file = "virtualenv-20.26.2.tar.gz", hash = "sha256:82bf0f4eebbb78d36ddaee0283d43fe5736b53880b8a8cdcd37390a07ac3741c"}, + {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, + {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, ] [package.dependencies] @@ -1635,5 +1661,5 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" -python-versions = "^3.11" -content-hash = "0998fb600e56ce267b916e0390f9f84850ffe5dada9440da44b7408138d2dd7e" +python-versions = "^3.10" +content-hash = "2e8d1cbec5bd39ac59c2cfbc1e185f6937c4a48593d3d7b44fa84e09b2d99517" diff --git a/spark/pyproject.toml b/spark/pyproject.toml index 2e9da91f..32c99a5f 100644 --- a/spark/pyproject.toml +++ b/spark/pyproject.toml @@ -6,9 +6,10 @@ version = "0.8.1" description = "Spark jobs collection for Radicalbit AI Monitoring Platform" authors = ["Radicalbit "] readme = "README.md" +package-mode = false [tool.poetry.dependencies] -python = "^3.11" +python = "^3.10" pyspark = "^3.5.1" pydantic = "^2.7.3" numpy = "^1.26.4" @@ -27,3 +28,6 @@ deepdiff = "^7.0.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +pythonpath = "jobs" diff --git a/spark/tests/current_test.py b/spark/tests/binary_current_test.py similarity index 92% rename from spark/tests/current_test.py rename to spark/tests/binary_current_test.py index b7db313c..87121d6a 100644 --- a/spark/tests/current_test.py +++ b/spark/tests/binary_current_test.py @@ -6,6 +6,9 @@ import pytest from pyspark.sql import SparkSession +from jobs.metrics.statistics import calculate_statistics_current +from jobs.models.current_dataset import CurrentDataset +from jobs.models.reference_dataset import ReferenceDataset from jobs.utils.current import CurrentMetricsService from jobs.utils.models import ( ModelOut, @@ -16,7 +19,6 @@ SupportedTypes, Granularity, ) -from jobs.utils.spark import apply_schema_to_dataframe from tests.utils.pytest_utils import my_approx test_resource_path = Path(__file__).resolve().parent / "resources" @@ -197,34 +199,20 @@ def test_calculation(spark_fixture, dataset): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) 
- reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp() @@ -235,10 +223,10 @@ def test_calculation(spark_fixture, dataset): "duplicate_rows": 3, "duplicate_rows_perc": 30.0, "missing_cells": 3, - "missing_cells_perc": 4.285714285714286, + "missing_cells_perc": 3.75, "n_observations": 10, - "n_variables": 7, - "numeric": 4, + "n_variables": 8, + "numeric": 5, } ) @@ -487,34 +475,20 @@ def test_calculation_current_joined(spark_fixture, current_joined): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = current_joined - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = current_joined + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() assert stats == my_approx( @@ -526,8 +500,8 @@ def test_calculation_current_joined(spark_fixture, current_joined): "missing_cells": 0, "missing_cells_perc": 0.0, "n_observations": 238, - "n_variables": 14, - "numeric": 12, + "n_variables": 15, + "numeric": 13, } ) @@ -1084,34 +1058,20 @@ def test_calculation_complete(spark_fixture, complete_dataset): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = complete_dataset - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, 
raw_reference_dataset = complete_dataset + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() assert stats == my_approx( @@ -1120,10 +1080,10 @@ def test_calculation_complete(spark_fixture, complete_dataset): "missing_cells_perc": 0.0, "duplicate_rows": 0, "duplicate_rows_perc": 0.0, - "n_variables": 7, + "n_variables": 8, "n_observations": 7, "numeric": 4, - "categorical": 2, + "categorical": 3, "datetime": 1, }, ) @@ -1285,34 +1245,20 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = easy_dataset - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = easy_dataset + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() assert stats == my_approx( @@ -1321,9 +1267,9 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset): "missing_cells_perc": 0.0, "duplicate_rows": 0, "duplicate_rows_perc": 0.0, - "n_variables": 7, + "n_variables": 8, "n_observations": 7, - "numeric": 4, + "numeric": 5, "categorical": 2, "datetime": 1, }, @@ -1486,45 +1432,31 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset_cat_missing - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset_cat_missing + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( 
spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() assert stats == my_approx( { "missing_cells": 5, - "missing_cells_perc": 7.142857142857142, + "missing_cells_perc": 6.25, "duplicate_rows": 2, "duplicate_rows_perc": 20.0, - "n_variables": 7, + "n_variables": 8, "n_observations": 10, - "numeric": 4, + "numeric": 5, "categorical": 2, "datetime": 1, } @@ -1702,34 +1634,20 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime) updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset_with_datetime - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset_with_datetime + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() assert stats == my_approx( @@ -1739,10 +1657,10 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime) "duplicate_rows": 3, "duplicate_rows_perc": 30.0, "missing_cells": 3, - "missing_cells_perc": 4.285714285714286, + "missing_cells_perc": 3.75, "n_observations": 10, - "n_variables": 7, - "numeric": 4, + "n_variables": 8, + "numeric": 5, } ) @@ -1918,34 +1836,20 @@ def test_calculation_easy_dataset_bucket_test(spark_fixture, easy_dataset_bucket updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = easy_dataset_bucket_test - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = easy_dataset_bucket_test + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = 
metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() assert stats == my_approx( @@ -1954,9 +1858,9 @@ def test_calculation_easy_dataset_bucket_test(spark_fixture, easy_dataset_bucket "missing_cells_perc": 0.0, "duplicate_rows": 0, "duplicate_rows_perc": 0.0, - "n_variables": 7, + "n_variables": 8, "n_observations": 7, - "numeric": 4, + "numeric": 5, "categorical": 2, "datetime": 1, }, @@ -2147,34 +2051,20 @@ def test_calculation_for_hour(spark_fixture, dataset_for_hour): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset_for_hour - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset_for_hour + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp() @@ -2185,10 +2075,10 @@ def test_calculation_for_hour(spark_fixture, dataset_for_hour): "duplicate_rows": 3, "duplicate_rows_perc": 30.0, "missing_cells": 3, - "missing_cells_perc": 4.285714285714286, + "missing_cells_perc": 3.75, "n_observations": 10, - "n_variables": 7, - "numeric": 4, + "n_variables": 8, + "numeric": 5, } ) @@ -2492,34 +2382,20 @@ def test_calculation_for_day(spark_fixture, dataset_for_day): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset_for_day - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset_for_day + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() model_quality = 
metrics_service.calculate_model_quality_with_group_by_timestamp() @@ -2530,10 +2406,10 @@ def test_calculation_for_day(spark_fixture, dataset_for_day): "duplicate_rows": 3, "duplicate_rows_perc": 30.0, "missing_cells": 3, - "missing_cells_perc": 4.285714285714286, + "missing_cells_perc": 3.75, "n_observations": 10, - "n_variables": 7, - "numeric": 4, + "n_variables": 8, + "numeric": 5, } ) @@ -2823,34 +2699,20 @@ def test_calculation_for_week(spark_fixture, dataset_for_week): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset_for_week - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset_for_week + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp() @@ -2861,10 +2723,10 @@ def test_calculation_for_week(spark_fixture, dataset_for_week): "duplicate_rows": 3, "duplicate_rows_perc": 30.0, "missing_cells": 3, - "missing_cells_perc": 4.285714285714286, + "missing_cells_perc": 3.75, "n_observations": 10, - "n_variables": 7, - "numeric": 4, + "n_variables": 8, + "numeric": 5, } ) @@ -3154,34 +3016,20 @@ def test_calculation_for_month(spark_fixture, dataset_for_month): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = dataset_for_month - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = dataset_for_month + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) + metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) - stats = metrics_service.calculate_statistics() + + stats = calculate_statistics_current(current_dataset) data_quality = metrics_service.calculate_data_quality() model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp() @@ -3192,10 +3040,10 @@ def 
test_calculation_for_month(spark_fixture, dataset_for_month): "duplicate_rows": 3, "duplicate_rows_perc": 30.0, "missing_cells": 3, - "missing_cells_perc": 4.285714285714286, + "missing_cells_perc": 3.75, "n_observations": 10, - "n_variables": 7, - "numeric": 4, + "n_variables": 8, + "numeric": 5, } ) diff --git a/spark/tests/drift_test.py b/spark/tests/binary_drift_test.py similarity index 80% rename from spark/tests/drift_test.py rename to spark/tests/binary_drift_test.py index f6b2c873..4f4ad5fd 100644 --- a/spark/tests/drift_test.py +++ b/spark/tests/binary_drift_test.py @@ -6,6 +6,8 @@ import pytest from pyspark.sql import SparkSession +from jobs.models.current_dataset import CurrentDataset +from jobs.models.reference_dataset import ReferenceDataset from jobs.utils.current import CurrentMetricsService from jobs.utils.models import ( ModelOut, @@ -16,7 +18,6 @@ SupportedTypes, Granularity, ) -from jobs.utils.spark import apply_schema_to_dataframe test_resource_path = Path(__file__).resolve().parent / "resources" @@ -112,31 +113,15 @@ def test_drift(spark_fixture, drift_dataset): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = drift_dataset - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = drift_dataset + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) @@ -214,35 +199,20 @@ def test_drift_small(spark_fixture, drift_small_dataset): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = drift_small_dataset - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = drift_small_dataset + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) drift = metrics_service.calculate_drift() + assert not deepdiff.DeepDiff( drift, { @@ -307,38 +277,20 @@ def test_drift_boolean(spark_fixture, drift_dataset_bool): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = 
drift_dataset_bool - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = drift_dataset_bool + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) drift = metrics_service.calculate_drift() - print(drift) - assert not deepdiff.DeepDiff( drift, { @@ -411,31 +363,15 @@ def test_drift_bigger_file(spark_fixture, drift_dataset_bigger_file): updated_at=str(datetime.datetime.now()), ) - current_dataset, reference_dataset = drift_dataset_bigger_file - current_dataset = apply_schema_to_dataframe( - current_dataset, model.to_current_spark_schema() - ) - current_dataset = current_dataset.select( - *[ - c - for c in model.to_current_spark_schema().names - if c in current_dataset.columns - ] - ) - reference_dataset = apply_schema_to_dataframe( - reference_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in current_dataset.columns - ] + raw_current_dataset, raw_reference_dataset = drift_dataset_bigger_file + current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current_dataset) + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=raw_reference_dataset ) metrics_service = CurrentMetricsService( spark_session=spark_fixture, - current=current_dataset, - reference=reference_dataset, + current=current_dataset.current, + reference=reference_dataset.reference, model=model, ) diff --git a/spark/tests/reference_test.py b/spark/tests/binary_reference_test.py similarity index 96% rename from spark/tests/reference_test.py rename to spark/tests/binary_reference_test.py index 6d5e2fa8..a4b43417 100644 --- a/spark/tests/reference_test.py +++ b/spark/tests/binary_reference_test.py @@ -7,6 +7,8 @@ import pytest from pyspark.sql import SparkSession +from jobs.metrics.statistics import calculate_statistics_reference +from jobs.models.reference_dataset import ReferenceDataset from jobs.utils.models import ( ModelOut, ModelType, @@ -17,7 +19,6 @@ Granularity, ) from jobs.utils.reference import ReferenceMetricsService -from jobs.utils.spark import apply_schema_to_dataframe from tests.utils.pytest_utils import my_approx test_resource_path = Path(__file__).resolve().parent / "resources" @@ -122,19 +123,10 @@ def test_calculation(spark_fixture, dataset): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) + reference_dataset = ReferenceDataset(model=model, 
raw_dataframe=dataset) + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) - stats = metrics_service.calculate_statistics() + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -351,19 +343,10 @@ def test_calculation_reference_joined(spark_fixture, reference_joined): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - reference_joined, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) + reference_dataset = ReferenceDataset(model=model, raw_dataframe=reference_joined) + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - stats = metrics_service.calculate_statistics() + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -946,18 +929,10 @@ def test_calculation_complete(spark_fixture, complete_dataset): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - complete_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - stats = metrics_service.calculate_statistics() + reference_dataset = ReferenceDataset(model=model, raw_dataframe=complete_dataset) + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) + + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -1147,18 +1122,10 @@ def test_calculation_easy_dataset(spark_fixture, easy_dataset): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - easy_dataset, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - stats = metrics_service.calculate_statistics() + reference_dataset = ReferenceDataset(model=model, raw_dataframe=easy_dataset) + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) + + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -1347,18 +1314,10 @@ def test_calculation_dataset_cat_missing(spark_fixture, dataset_cat_missing): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - dataset_cat_missing, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - stats = metrics_service.calculate_statistics() + reference_dataset = ReferenceDataset(model=model, raw_dataframe=dataset_cat_missing) + metrics_service = 
ReferenceMetricsService(reference_dataset.reference, model=model) + + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -1568,18 +1527,12 @@ def test_calculation_dataset_with_datetime(spark_fixture, dataset_with_datetime) updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - dataset_with_datetime, model.to_reference_spark_schema() + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=dataset_with_datetime ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - stats = metrics_service.calculate_statistics() + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) + + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -1797,19 +1750,10 @@ def test_calculation_enhanced_data(spark_fixture, enhanced_data): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - enhanced_data, model.to_reference_spark_schema() - ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) + reference_dataset = ReferenceDataset(model=model, raw_dataframe=enhanced_data) + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) - stats = metrics_service.calculate_statistics() + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() @@ -2534,18 +2478,12 @@ def test_calculation_dataset_bool_missing(spark_fixture, dataset_bool_missing): updated_at=str(datetime.datetime.now()), ) - reference_dataset = apply_schema_to_dataframe( - dataset_bool_missing, model.to_reference_spark_schema() + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=dataset_bool_missing ) - reference_dataset = reference_dataset.select( - *[ - c - for c in model.to_reference_spark_schema().names - if c in reference_dataset.columns - ] - ) - metrics_service = ReferenceMetricsService(reference_dataset, model=model) - stats = metrics_service.calculate_statistics() + metrics_service = ReferenceMetricsService(reference_dataset.reference, model=model) + + stats = calculate_statistics_reference(reference_dataset) model_quality = metrics_service.calculate_model_quality() data_quality = metrics_service.calculate_data_quality() diff --git a/spark/tests/multiclass_reference_test.py b/spark/tests/multiclass_reference_test.py new file mode 100644 index 00000000..c711b727 --- /dev/null +++ b/spark/tests/multiclass_reference_test.py @@ -0,0 +1,146 @@ +import datetime +import uuid +from pathlib import Path + +import pytest +from pyspark.sql import SparkSession + +from jobs.metrics.statistics import calculate_statistics_reference +from jobs.models.reference_dataset import ReferenceDataset +from jobs.utils.models import ( + ModelOut, + ModelType, + DataType, + OutputType, + ColumnDefinition, + SupportedTypes, + Granularity, +) +from tests.utils.pytest_utils import my_approx + +test_resource_path = 
Path(__file__).resolve().parent / "resources" + + +@pytest.fixture() +def spark_fixture(): + spark = SparkSession.builder.appName("Reference Multiclass PyTest").getOrCreate() + yield spark + + +@pytest.fixture() +def dataset_target_int(spark_fixture): + yield spark_fixture.read.csv( + f"{test_resource_path}/reference/multiclass/dataset_target_int.csv", header=True + ) + + +@pytest.fixture() +def dataset_target_string(spark_fixture): + yield spark_fixture.read.csv( + f"{test_resource_path}/reference/multiclass/dataset_target_string.csv", + header=True, + ) + + +def test_calculation_dataset_target_int(spark_fixture, dataset_target_int): + output = OutputType( + prediction=ColumnDefinition(name="prediction", type=SupportedTypes.int), + prediction_proba=None, + output=[ColumnDefinition(name="prediction", type=SupportedTypes.int)], + ) + target = ColumnDefinition(name="target", type=SupportedTypes.int) + timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime) + granularity = Granularity.HOUR + features = [ + ColumnDefinition(name="cat1", type=SupportedTypes.string), + ColumnDefinition(name="cat2", type=SupportedTypes.string), + ColumnDefinition(name="num1", type=SupportedTypes.float), + ColumnDefinition(name="num2", type=SupportedTypes.float), + ] + model = ModelOut( + uuid=uuid.uuid4(), + name="model", + description="description", + model_type=ModelType.MULTI_CLASS, + data_type=DataType.TABULAR, + timestamp=timestamp, + granularity=granularity, + outputs=output, + target=target, + features=features, + frameworks="framework", + algorithm="algorithm", + created_at=str(datetime.datetime.now()), + updated_at=str(datetime.datetime.now()), + ) + + reference_dataset = ReferenceDataset(model=model, raw_dataframe=dataset_target_int) + + stats = calculate_statistics_reference(reference_dataset) + + assert stats == my_approx( + { + "categorical": 2, + "datetime": 1, + "duplicate_rows": 0, + "duplicate_rows_perc": 0.0, + "missing_cells": 3, + "missing_cells_perc": 4.285714285714286, + "n_observations": 10, + "n_variables": 7, + "numeric": 4, + } + ) + + +def test_calculation_dataset_target_string(spark_fixture, dataset_target_string): + output = OutputType( + prediction=ColumnDefinition(name="prediction", type=SupportedTypes.string), + prediction_proba=None, + output=[ColumnDefinition(name="prediction", type=SupportedTypes.string)], + ) + target = ColumnDefinition(name="target", type=SupportedTypes.string) + timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime) + granularity = Granularity.HOUR + features = [ + ColumnDefinition(name="cat1", type=SupportedTypes.string), + ColumnDefinition(name="cat2", type=SupportedTypes.string), + ColumnDefinition(name="num1", type=SupportedTypes.float), + ColumnDefinition(name="num2", type=SupportedTypes.float), + ] + model = ModelOut( + uuid=uuid.uuid4(), + name="model", + description="description", + model_type=ModelType.MULTI_CLASS, + data_type=DataType.TABULAR, + timestamp=timestamp, + granularity=granularity, + outputs=output, + target=target, + features=features, + frameworks="framework", + algorithm="algorithm", + created_at=str(datetime.datetime.now()), + updated_at=str(datetime.datetime.now()), + ) + + reference_dataset = ReferenceDataset( + model=model, raw_dataframe=dataset_target_string + ) + + stats = calculate_statistics_reference(reference_dataset) + + assert stats == my_approx( + { + "categorical": 4, + "datetime": 1, + "duplicate_rows": 0, + "duplicate_rows_perc": 0.0, + "missing_cells": 3, + "missing_cells_perc": 
4.285714285714286, + "n_observations": 10, + "n_variables": 7, + "numeric": 2, + } + ) diff --git a/spark/tests/resources/reference/multiclass/dataset_target_int.csv b/spark/tests/resources/reference/multiclass/dataset_target_int.csv new file mode 100644 index 00000000..6f231d87 --- /dev/null +++ b/spark/tests/resources/reference/multiclass/dataset_target_int.csv @@ -0,0 +1,11 @@ +cat1,cat2,num1,num2,prediction,target,datetime +A,X,1.0,1.4,1,1,2024-06-16 00:01:00-05:00 +B,X,1.5,100.0,0,0,2024-06-16 00:02:00-05:00 +A,Y,3.0,123.0,1,1,2024-06-16 00:03:00-05:00 +B,X,0.5,,2,0,2024-06-16 00:04:00-05:00 +B,X,0.5,,3,2,2024-06-16 00:05:00-05:00 +B,X,,200.0,1,3,2024-06-16 00:06:00-05:00 +C,X,1.0,300.0,0,0,2024-06-16 00:07:00-05:00 +A,X,1.0,499.0,2,2,2024-06-16 00:08:00-05:00 +A,X,1.0,499.0,1,1,2024-06-16 00:09:00-05:00 +A,X,1.0,499.0,3,2,2024-06-16 00:10:00-05:00 \ No newline at end of file diff --git a/spark/tests/resources/reference/multiclass/dataset_target_string.csv b/spark/tests/resources/reference/multiclass/dataset_target_string.csv new file mode 100644 index 00000000..36c0ddaf --- /dev/null +++ b/spark/tests/resources/reference/multiclass/dataset_target_string.csv @@ -0,0 +1,11 @@ +cat1,cat2,num1,num2,prediction,target,datetime +A,X,1.0,1.4,HEALTY,HEALTY,2024-06-16 00:01:00-05:00 +B,X,1.5,100.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:02:00-05:00 +A,Y,3.0,123.0,HEALTY,HEALTY,2024-06-16 00:03:00-05:00 +B,X,0.5,,UNKNOWN,UNHEALTHY,2024-06-16 00:04:00-05:00 +B,X,0.5,,ORPHAN,UNKNOWN,2024-06-16 00:05:00-05:00 +B,X,,200.0,HEALTY,ORPHAN,2024-06-16 00:06:00-05:00 +C,X,1.0,300.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:07:00-05:00 +A,X,1.0,499.0,UNKNOWN,UNKNOWN,2024-06-16 00:08:00-05:00 +A,X,1.0,499.0,HEALTY,HEALTY,2024-06-16 00:09:00-05:00 +A,X,1.0,499.0,ORPHAN,UNKNOWN,2024-06-16 00:10:00-05:00 \ No newline at end of file
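
Note on the refactor above: every renamed test follows the same wrapper pattern, so a compact standalone sketch is included here. It only uses calls visible in this diff (`ReferenceDataset`, `calculate_statistics_reference`, and the `ModelOut` construction mirrored from `multiclass_reference_test.py`); the CSV path, the `appName`, and running it as a plain script from the `spark/` folder with the project root on `PYTHONPATH` (rather than through pytest) are illustrative assumptions, not part of the change itself.

```python
# Minimal sketch of the new dataset-wrapper flow: read the raw CSV with header
# only, let ReferenceDataset apply the model schema internally, and compute
# statistics with the standalone helper instead of the metrics service.
import datetime
import uuid

from pyspark.sql import SparkSession

from jobs.metrics.statistics import calculate_statistics_reference
from jobs.models.reference_dataset import ReferenceDataset
from jobs.utils.models import (
    ModelOut,
    ModelType,
    DataType,
    OutputType,
    ColumnDefinition,
    SupportedTypes,
    Granularity,
)

spark = SparkSession.builder.appName("Reference statistics sketch").getOrCreate()

# ModelOut mirrored from multiclass_reference_test.py above.
model = ModelOut(
    uuid=uuid.uuid4(),
    name="model",
    description="description",
    model_type=ModelType.MULTI_CLASS,
    data_type=DataType.TABULAR,
    timestamp=ColumnDefinition(name="datetime", type=SupportedTypes.datetime),
    granularity=Granularity.HOUR,
    outputs=OutputType(
        prediction=ColumnDefinition(name="prediction", type=SupportedTypes.int),
        prediction_proba=None,
        output=[ColumnDefinition(name="prediction", type=SupportedTypes.int)],
    ),
    target=ColumnDefinition(name="target", type=SupportedTypes.int),
    features=[
        ColumnDefinition(name="cat1", type=SupportedTypes.string),
        ColumnDefinition(name="cat2", type=SupportedTypes.string),
        ColumnDefinition(name="num1", type=SupportedTypes.float),
        ColumnDefinition(name="num2", type=SupportedTypes.float),
    ],
    frameworks="framework",
    algorithm="algorithm",
    created_at=str(datetime.datetime.now()),
    updated_at=str(datetime.datetime.now()),
)

# Read the raw CSV without a schema; the wrapper enforces the model schema.
# Path assumes the script is launched from the spark/ project folder.
raw_reference = spark.read.csv(
    "tests/resources/reference/multiclass/dataset_target_int.csv", header=True
)
reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_reference)

# n_variables is 7 for this model: 4 features + prediction + target + timestamp.
stats = calculate_statistics_reference(reference_dataset)
print(stats)
```

The design choice visible throughout the diff is that schema application moves inside the `CurrentDataset`/`ReferenceDataset` wrappers, which is what removes the repeated `apply_schema_to_dataframe` and column-select boilerplate from every test and job.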