Commit 0eaa8b8

chore: fix cyberml tests
mhamilton723 committed Apr 3, 2023
1 parent 25ba729 commit 0eaa8b8
Showing 1 changed file with 13 additions and 13 deletions.
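
Every hunk below makes the same substitution: the test DataFrames are now built through a `spark` handle instead of `sc`. In PySpark, `createDataFrame` is provided by `SparkSession` rather than by `SparkContext`, so a session handle named `spark` is the conventional way to make these calls. A minimal sketch of the resulting pattern follows; the session setup, sample values, and the `"tenant_id"` column name are assumptions for illustration, not the repository's actual fixtures.

    # Minimal sketch, assuming a locally created session; the real tests get
    # `spark` from the project's shared test utilities.
    import pandas as pd
    import pyspark.sql.functions as f
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    # createDataFrame is a SparkSession method. It accepts a pandas DataFrame
    # (as returned by the test data factory) or a list of rows plus a schema.
    training_pdf = pd.DataFrame(
        {"user": ["roy", "roy"], "res": ["res1", "res2"], "score": [4.0, 8.0]}
    )
    curr_training = spark.createDataFrame(training_pdf).withColumn(
        "tenant_id",  # hypothetical stand-in for AccessAnomalyConfig.default_tenant_col
        f.lit(0),
    )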
@@ -136,17 +136,17 @@ def __init__(self):
 intra_test_pdf = factory.create_clustered_intra_test_data(training_pdf)
 inter_test_pdf = factory.create_clustered_inter_test_data()
 
-curr_training = sc.createDataFrame(training_pdf).withColumn(
+curr_training = spark.createDataFrame(training_pdf).withColumn(
     AccessAnomalyConfig.default_tenant_col,
     f.lit(tid),
 )
 
-curr_intra_test = sc.createDataFrame(intra_test_pdf).withColumn(
+curr_intra_test = spark.createDataFrame(intra_test_pdf).withColumn(
     AccessAnomalyConfig.default_tenant_col,
     f.lit(tid),
 )
 
-curr_inter_test = sc.createDataFrame(inter_test_pdf).withColumn(
+curr_inter_test = spark.createDataFrame(inter_test_pdf).withColumn(
     AccessAnomalyConfig.default_tenant_col,
     f.lit(tid),
 )
@@ -212,7 +212,7 @@ def create_new_training(ratio: float) -> DataFrame:
 training_pdf = create_data_factory().create_clustered_training_data(ratio)
 
 return materialized_cache(
-    sc.createDataFrame(training_pdf).withColumn(
+    spark.createDataFrame(training_pdf).withColumn(
         AccessAnomalyConfig.default_tenant_col,
         f.lit(0),
     ),
@@ -274,18 +274,18 @@ def test_model_standard_scaling(self):
 )
 
 df = materialized_cache(
-    sc.createDataFrame(
+    spark.createDataFrame(
         [["0", "roy", "res1", 4.0], ["0", "roy", "res2", 8.0]],
         df_schema,
     ),
 )
 
 user_mapping_df = materialized_cache(
-    sc.createDataFrame([["0", "roy", [1.0, 1.0, 0.0, 1.0]]], user_model_schema),
+    spark.createDataFrame([["0", "roy", [1.0, 1.0, 0.0, 1.0]]], user_model_schema),
 )
 
 res_mapping_df = materialized_cache(
-    sc.createDataFrame(
+    spark.createDataFrame(
         [
             ["0", "res1", [2.0, 2.0, 1.0, 0.0]],
             ["0", "res2", [4.0, 4.0, 1.0, 0.0]],
@@ -377,7 +377,7 @@ def test_model_end2end(self):
 )
 
 user_mapping_df = materialized_cache(
-    sc.createDataFrame(
+    spark.createDataFrame(
         [
             [
                 0,
@@ -391,7 +391,7 @@ def test_model_end2end(self):
 )
 
 res_mapping_df = materialized_cache(
-    sc.createDataFrame(
+    spark.createDataFrame(
         [
             [
                 0,
@@ -414,7 +414,7 @@ def test_model_end2end(self):
 assert df.count() == num_users * num_resources
 
 user_mapping_df = materialized_cache(
-    sc.createDataFrame(
+    spark.createDataFrame(
         [
             [
                 0,
@@ -428,7 +428,7 @@ def test_model_end2end(self):
 )
 
 res_mapping_df = materialized_cache(
-    sc.createDataFrame(
+    spark.createDataFrame(
         [
             [
                 0,
@@ -897,7 +897,7 @@ def test_simple(self):
     ],
 )
 
-df = sc.createDataFrame(
+df = spark.createDataFrame(
     [
         ["0", "user0", "res0", 4.0],
         ["0", "user1", "res0", 8.0],
@@ -944,7 +944,7 @@ def test_datafactory(self):
 user_col = "user"
 res_col = "res"
 
-df = sc.createDataFrame(
+df = spark.createDataFrame(
     DataFactory(single_component=False).create_clustered_training_data(),
 ).withColumn(tenant_col, f.lit(0))

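Several of the rebuilt DataFrames above are wrapped in `materialized_cache`. Its definition is not part of this diff; as a rough illustration of what such a test helper typically does (an assumption, not the repository's actual implementation), it caches the DataFrame and forces evaluation so repeated assertions do not recompute the whole lineage.

    # Hypothetical reconstruction of a materialized-cache helper; the name and
    # behavior are assumptions for illustration only.
    from pyspark.sql import DataFrame

    def materialized_cache(df: DataFrame) -> DataFrame:
        cached = df.cache()
        cached.count()  # force evaluation so the cache is populated up front
        return cached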
