feat: refactor jobs to test output (#120)
SteZamboni authored Jul 18, 2024
1 parent 43fc389 commit 67e94e9
Showing 6 changed files with 818 additions and 72 deletions.
81 changes: 46 additions & 35 deletions spark/jobs/current_job.py
@@ -19,41 +19,8 @@
 from pyspark.sql import SparkSession
 
 
-def main(
-    spark_session: SparkSession,
-    model: ModelOut,
-    current_dataset_path: str,
-    current_uuid: str,
-    reference_dataset_path: str,
-    table_name: str,
-):
-    spark_context = spark_session.sparkContext
-
-    spark_context._jsc.hadoopConfiguration().set(
-        "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID")
-    )
-    spark_context._jsc.hadoopConfiguration().set(
-        "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY")
-    )
-    spark_context._jsc.hadoopConfiguration().set(
-        "fs.s3a.endpoint.region", os.getenv("AWS_REGION")
-    )
-    if os.getenv("S3_ENDPOINT_URL"):
-        spark_context._jsc.hadoopConfiguration().set(
-            "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL")
-        )
-        spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
-        spark_context._jsc.hadoopConfiguration().set(
-            "fs.s3a.connection.ssl.enabled", "false"
-        )
-
-    raw_current = spark_session.read.csv(current_dataset_path, header=True)
-    current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current)
-    raw_reference = spark_session.read.csv(reference_dataset_path, header=True)
-    reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_reference)
-
-    complete_record = {"UUID": str(uuid.uuid4()), "CURRENT_UUID": current_uuid}
-
+def compute_metrics(spark_session, current_dataset, reference_dataset, model):
+    complete_record = {}
     match model.model_type:
         case ModelType.BINARY:
             metrics_service = CurrentMetricsService(
@@ -118,6 +85,50 @@ def main(
             )
             complete_record["DRIFT"] = orjson.dumps(drift).decode("utf-8")
 
+    return complete_record
+
+
+def main(
+    spark_session: SparkSession,
+    model: ModelOut,
+    current_dataset_path: str,
+    current_uuid: str,
+    reference_dataset_path: str,
+    table_name: str,
+):
+    spark_context = spark_session.sparkContext
+
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.endpoint.region", os.getenv("AWS_REGION")
+    )
+    if os.getenv("S3_ENDPOINT_URL"):
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL")
+        )
+        spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.connection.ssl.enabled", "false"
+        )
+
+    raw_current = spark_session.read.csv(current_dataset_path, header=True)
+    current_dataset = CurrentDataset(model=model, raw_dataframe=raw_current)
+    raw_reference = spark_session.read.csv(reference_dataset_path, header=True)
+    reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_reference)
+
+    complete_record = compute_metrics(
+        spark_session=spark_session,
+        current_dataset=current_dataset,
+        reference_dataset=reference_dataset,
+        model=model,
+    )
+    complete_record.update({"UUID": str(uuid.uuid4()), "CURRENT_UUID": current_uuid})
+
     schema = StructType(
         [
             StructField("UUID", StringType(), True),
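The seam this refactor creates is visible above: compute_metrics now receives ready-made datasets and returns a plain dict, while all the S3/Hadoop wiring stays in main. That makes the metrics logic testable against a local SparkSession with in-memory DataFrames. A minimal pytest sketch of the idea; the import paths, the make_binary_model helper, and the column names are illustrative assumptions, not part of this commit:

import pytest
from pyspark.sql import SparkSession

from jobs.current_job import compute_metrics  # import path assumed
from models.current_dataset import CurrentDataset  # import paths assumed
from models.reference_dataset import ReferenceDataset


@pytest.fixture(scope="session")
def spark_session():
    # A plain local session suffices: compute_metrics no longer reads
    # from S3 or touches hadoopConfiguration.
    return SparkSession.builder.master("local[1]").getOrCreate()


def test_binary_metrics_record(spark_session):
    model = make_binary_model()  # hypothetical helper producing a ModelOut
    raw = spark_session.createDataFrame(
        [(1.0, 1.0), (0.0, 1.0)], ["prediction", "ground_truth"]  # columns assumed
    )
    current = CurrentDataset(model=model, raw_dataframe=raw)
    reference = ReferenceDataset(model=model, raw_dataframe=raw)

    record = compute_metrics(spark_session, current, reference, model)

    # Each section is serialized to a JSON string before being written,
    # e.g. record["DRIFT"] = orjson.dumps(drift).decode("utf-8").
    assert isinstance(record["DRIFT"], str)

Attaching the UUIDs in main via complete_record.update(...) rather than inside compute_metrics keeps the returned record free of per-run randomness, which is what makes assertions like the one above practical.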
72 changes: 40 additions & 32 deletions spark/jobs/reference_job.py
@@ -19,38 +19,8 @@
 from utils.reference_multiclass import ReferenceMetricsMulticlassService
 
 
-def main(
-    spark_session: SparkSession,
-    model: ModelOut,
-    reference_dataset_path: str,
-    reference_uuid: str,
-    table_name: str,
-):
-    spark_context = spark_session.sparkContext
-
-    spark_context._jsc.hadoopConfiguration().set(
-        "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID")
-    )
-    spark_context._jsc.hadoopConfiguration().set(
-        "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY")
-    )
-    spark_context._jsc.hadoopConfiguration().set(
-        "fs.s3a.endpoint.region", os.getenv("AWS_REGION")
-    )
-    if os.getenv("S3_ENDPOINT_URL"):
-        spark_context._jsc.hadoopConfiguration().set(
-            "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL")
-        )
-        spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
-        spark_context._jsc.hadoopConfiguration().set(
-            "fs.s3a.connection.ssl.enabled", "false"
-        )
-
-    raw_dataframe = spark_session.read.csv(reference_dataset_path, header=True)
-    reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_dataframe)
-
-    complete_record = {"UUID": str(uuid.uuid4()), "REFERENCE_UUID": reference_uuid}
-
+def compute_metrics(reference_dataset, model):
+    complete_record = {}
     match model.model_type:
         case ModelType.BINARY:
             metrics_service = ReferenceMetricsService(reference=reference_dataset)
@@ -99,6 +69,44 @@ def main(
             complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
                 serialize_as_any=True
             )
+    return complete_record
+
+
+def main(
+    spark_session: SparkSession,
+    model: ModelOut,
+    reference_dataset_path: str,
+    reference_uuid: str,
+    table_name: str,
+):
+    spark_context = spark_session.sparkContext
+
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.endpoint.region", os.getenv("AWS_REGION")
+    )
+    if os.getenv("S3_ENDPOINT_URL"):
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL")
+        )
+        spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.connection.ssl.enabled", "false"
+        )
+
+    raw_dataframe = spark_session.read.csv(reference_dataset_path, header=True)
+    reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_dataframe)
+
+    complete_record = compute_metrics(reference_dataset, model)
+
+    complete_record.update(
+        {"UUID": str(uuid.uuid4()), "REFERENCE_UUID": reference_uuid}
+    )
 
     schema = StructType(
         [
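The reference job gets the same treatment, and its compute_metrics(reference_dataset, model) needs no SparkSession argument at all, so the isolated test is even smaller. A companion sketch under the same assumptions as above (shared spark_session fixture, hypothetical make_binary_model, assumed import paths and columns):

from jobs.reference_job import compute_metrics as reference_metrics  # path assumed


def test_reference_metrics_record(spark_session):
    model = make_binary_model()  # hypothetical helper, as above
    raw = spark_session.createDataFrame(
        [(1.0, 1.0), (0.0, 0.0)], ["prediction", "ground_truth"]  # columns assumed
    )
    reference = ReferenceDataset(model=model, raw_dataframe=raw)

    record = reference_metrics(reference, model)

    # Sections are stored as JSON strings, e.g. DATA_QUALITY via
    # data_quality.model_dump_json(serialize_as_any=True).
    assert isinstance(record["DATA_QUALITY"], str)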
