Skip to content

Commit

Permalink
feat: new PSI value, chi2 new method fixed (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
SteZamboni authored Jul 24, 2024
1 parent 92342b9 commit d59f33c
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 108 deletions.
46 changes: 7 additions & 39 deletions spark/jobs/metrics/chi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,43 +178,6 @@ def __prepare_data_for_test(
)
return vector_data.select(reference_column, f"{current_column}_vector")

def __concatenate_dfs(self) -> pyspark.sql.DataFrame:
"""
Concatenates the reference and current dataframes if they have the same size or creates subsamples to make them
of equal size.
Returns:
- pyspark.sql.DataFrame: The concatenated DataFrame.
"""

self.current = self.current.withColumn("type", F.lit("current"))
self.reference = self.reference.withColumn("type", F.lit("reference"))

if self.__have_same_size():
concatenated_data = self.current.unionByName(self.reference)

else:
max_size = max(self.reference_size, self.current_size)

if self.reference_size == max_size:
# create a reference subsample with a size equal to the current
subsample_reference = self.spark_session.createDataFrame(
self.reference.rdd.takeSample(
withReplacement=True, num=self.current_size, seed=1990
)
)
concatenated_data = self.current.unionByName(subsample_reference)
else:
# create a current subsample with a size equal to the reference
subsample_current = self.spark_session.createDataFrame(
self.current.rdd.takeSample(
withReplacement=True, num=self.reference_size, seed=1990
)
)
concatenated_data = subsample_current.unionByName(self.reference)

return concatenated_data

def test_independence(self, reference_column, current_column) -> Dict:
"""
Performs the chi-square test of independence.
Expand Down Expand Up @@ -276,7 +239,10 @@ def test_goodness_fit(self, reference_column, current_column) -> Dict:
self.reference_size = self.reference.count()
self.current_size = self.current.count()

concatenated_data = self.__concatenate_dfs()
self.current = self.current.withColumn("type", F.lit("current"))
self.reference = self.reference.withColumn("type", F.lit("reference"))

concatenated_data = self.current.unionByName(self.reference)

def cnt_cond(cond):
return F.sum(F.when(cond, 1).otherwise(0))
Expand All @@ -295,5 +261,7 @@ def cnt_cond(cond):
.rdd.flatMap(lambda x: x)
.collect()
)
res = chisquare(ref_fr, cur_fr)
proportion = sum(cur_fr) / sum(ref_fr)
ref_fr = ref_fr * proportion
res = chisquare(cur_fr, ref_fr)
return {"pValue": float(res[1]), "statistic": float(res[0])}
65 changes: 22 additions & 43 deletions spark/jobs/metrics/drift_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,12 @@ def calculate_drift(
"type": "CHI2",
},
}
if (
reference_dataset.reference_count > 5
and current_dataset.current_count > 5
):
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["value"] = None
feature_dict_to_append["drift_calc"]["has_drift"] = False
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(result_tmp["pValue"])
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
drift_result["feature_metrics"].append(feature_dict_to_append)

float_features = [
Expand Down Expand Up @@ -86,20 +78,13 @@ def calculate_drift(
]
if len(unique_values_tot) < 15:
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
if (
reference_dataset.reference_count > 5
and current_dataset.current_count > 5
):
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["value"] = None
feature_dict_to_append["drift_calc"]["has_drift"] = False
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["type"] = "KS"
result_tmp = ks.test(column, column)
Expand Down Expand Up @@ -145,28 +130,22 @@ def calculate_drift(
]
if len(unique_values_tot) < 15:
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
if (
reference_dataset.reference_count > 5
and current_dataset.current_count > 5
):
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["value"] = None
feature_dict_to_append["drift_calc"]["has_drift"] = False
feature_dict_to_append["drift_calc"]["type"] = "CHI2"
result_tmp = chi2.test_goodness_fit(column, column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["pValue"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["pValue"] <= 0.05
)
else:
feature_dict_to_append["drift_calc"]["type"] = "PSI"
result_tmp = psi_obj.calculate_psi(column)
feature_dict_to_append["drift_calc"]["value"] = float(
result_tmp["psi_value"]
)
feature_dict_to_append["drift_calc"]["has_drift"] = bool(
result_tmp["psi_value"] >= 0.2
result_tmp["psi_value"] >= 0.1
)
drift_result["feature_metrics"].append(feature_dict_to_append)

Expand Down
76 changes: 56 additions & 20 deletions spark/tests/results/drift_calculator_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"feature_name": "cat2",
"drift_calc": {
"type": "CHI2",
"value": 3.4671076872696563e-05,
"value": 2.5396285894708634e-10,
"has_drift": True,
},
},
Expand All @@ -27,19 +27,31 @@
"feature_metrics": [
{
"feature_name": "cat1",
"drift_calc": {"type": "CHI2", "value": None, "has_drift": False},
"drift_calc": {
"type": "CHI2",
"value": 0.7788007830714049,
"has_drift": False,
},
},
{
"feature_name": "cat2",
"drift_calc": {"type": "CHI2", "value": None, "has_drift": False},
"drift_calc": {
"type": "CHI2",
"value": 0.007660761135179449,
"has_drift": True,
},
},
{
"feature_name": "num1",
"drift_calc": {"type": "CHI2", "value": None, "has_drift": False},
"drift_calc": {"type": "CHI2", "value": 0.0, "has_drift": True},
},
{
"feature_name": "num2",
"drift_calc": {"type": "CHI2", "value": None, "has_drift": False},
"drift_calc": {
"type": "CHI2",
"value": 0.4158801869955079,
"has_drift": False,
},
},
]
}
Expand Down Expand Up @@ -69,14 +81,18 @@
"feature_metrics": [
{
"feature_name": "cat1",
"drift_calc": {"type": "CHI2", "value": 0.0, "has_drift": True},
"drift_calc": {
"type": "CHI2",
"value": 0.7074036474040617,
"has_drift": False,
},
},
{
"feature_name": "cat2",
"drift_calc": {
"type": "CHI2",
"value": 0.052807511416113395,
"has_drift": False,
"value": 1.3668274623882378e-07,
"has_drift": True,
},
},
{
Expand All @@ -85,7 +101,11 @@
},
{
"feature_name": "num2",
"drift_calc": {"type": "CHI2", "value": 0.0, "has_drift": True},
"drift_calc": {
"type": "CHI2",
"value": 0.9282493523958153,
"has_drift": False,
},
},
]
}
Expand All @@ -96,8 +116,8 @@
"feature_name": "weathersit",
"drift_calc": {
"type": "CHI2",
"value": 0.002631773674724352,
"has_drift": True,
"value": 0.5328493415823949,
"has_drift": False,
},
},
{
Expand All @@ -118,33 +138,49 @@
},
{
"feature_name": "season",
"drift_calc": {"type": "CHI2", "value": 0.0, "has_drift": True},
"drift_calc": {
"type": "CHI2",
"value": 1.5727996539817032e-36,
"has_drift": True,
},
},
{
"feature_name": "yr",
"drift_calc": {"type": "CHI2", "value": 0.0, "has_drift": True},
"drift_calc": {
"type": "CHI2",
"value": 1.3270931223367946e-23,
"has_drift": True,
},
},
{
"feature_name": "mnth",
"drift_calc": {"type": "CHI2", "value": 0.0, "has_drift": True},
"drift_calc": {
"type": "CHI2",
"value": 1.1116581506687278e-44,
"has_drift": True,
},
},
{
"feature_name": "holiday",
"drift_calc": {"type": "CHI2", "value": 1.0, "has_drift": False},
"drift_calc": {
"type": "CHI2",
"value": 0.6115640463654775,
"has_drift": False,
},
},
{
"feature_name": "weekday",
"drift_calc": {
"type": "CHI2",
"value": 0.7855334068007708,
"value": 0.9998169413361089,
"has_drift": False,
},
},
{
"feature_name": "workingday",
"drift_calc": {
"type": "CHI2",
"value": 0.6625205835400574,
"value": 0.730645812540401,
"has_drift": False,
},
},
Expand All @@ -165,7 +201,7 @@
"feature_name": "has_5g",
"drift_calc": {
"type": "CHI2",
"value": 0.6528883189652503,
"value": 0.652356328876868,
"has_drift": False,
},
},
Expand Down Expand Up @@ -237,15 +273,15 @@
"feature_name": "internal_memory",
"drift_calc": {
"type": "CHI2",
"value": 0.9999999717288575,
"value": 0.9999999710826085,
"has_drift": False,
},
},
{
"feature_name": "refresh_rate",
"drift_calc": {
"type": "CHI2",
"value": 0.9997655785111437,
"value": 0.9997690818736329,
"has_drift": False,
},
},
Expand Down
Loading

0 comments on commit d59f33c

Please sign in to comment.