Skip to content

Commit

Permalink
Add coverage for csvs module.
Browse files Browse the repository at this point in the history
  • Loading branch information
calina-c committed Jan 5, 2024
1 parent 7d5c769 commit d205aa5
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 44 deletions.
78 changes: 34 additions & 44 deletions pdr_backend/util/csvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def check_and_create_dir(dir_path: str):


@enforce_types
def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str):
def _save_prediction_csv(
all_predictions: List[Prediction],
csv_output_dir: str,
headers: List,
attribute_names: List,
):
check_and_create_dir(csv_output_dir)

data = generate_prediction_data_structure(all_predictions)
Expand All @@ -56,58 +61,43 @@ def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str):
with open(filename, "w", newline="") as file:
writer = csv.writer(file)

writer.writerow(
["Predicted Value", "True Value", "Timestamp", "Stake", "Payout"]
)
writer.writerow(headers)

for prediction in predictions:
writer.writerow(
[
prediction.prediction,
prediction.trueval,
prediction.timestamp,
prediction.stake,
prediction.payout,
getattr(prediction, attribute_name)
for attribute_name in attribute_names
]
)

print(f"CSV file '{filename}' created successfully.")


@enforce_types
def save_analysis_csv(all_predictions: List[Prediction], csv_output_dir: str):
check_and_create_dir(csv_output_dir)

data = generate_prediction_data_structure(all_predictions)
def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str):
_save_prediction_csv(
all_predictions,
csv_output_dir,
["Predicted Value", "True Value", "Timestamp", "Stake", "Payout"],
["prediction", "trueval", "timestamp", "stake", "payout"],
)

for key, predictions in data.items():
predictions.sort(key=lambda x: x.timestamp)
filename = key_csv_filename_with_dir(csv_output_dir, key)
with open(filename, "w", newline="") as file:
writer = csv.writer(file)
writer.writerow(
[
"PredictionID",
"Timestamp",
"Slot",
"Stake",
"Wallet",
"Payout",
"True Value",
"Predicted Value",
]
)

for prediction in predictions:
writer.writerow(
[
prediction.ID,
prediction.timestamp,
prediction.slot,
prediction.stake,
prediction.user,
prediction.payout,
prediction.trueval,
prediction.prediction,
]
)
print(f"CSV file '{filename}' created successfully.")
@enforce_types
def save_analysis_csv(all_predictions: List[Prediction], csv_output_dir: str):
_save_prediction_csv(
all_predictions,
csv_output_dir,
[
"PredictionID",
"Timestamp",
"Slot",
"Stake",
"Wallet",
"Payout",
"True Value",
"Predicted Value",
],
["ID", "timestamp", "slot", "stake", "user", "payout", "trueval", "prediction"],
)
58 changes: 58 additions & 0 deletions pdr_backend/util/test_noganache/test_csvs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import csv
import os

from pdr_backend.subgraph.prediction import mock_daily_predictions
from pdr_backend.util.csvs import save_analysis_csv, save_prediction_csv


def test_save_analysis_csv(tmpdir):
predictions = mock_daily_predictions()
key = (
predictions[0].pair.replace("/", "-")
+ predictions[0].timeframe
+ predictions[0].source
)
save_analysis_csv(predictions, str(tmpdir))

with open(os.path.join(str(tmpdir), key + ".csv")) as f:
data = csv.DictReader(f)
data_rows = list(data)

assert data_rows[0]["Predicted Value"] == str(predictions[0].prediction)
assert data_rows[0]["True Value"] == str(predictions[0].trueval)
assert data_rows[0]["Timestamp"] == str(predictions[0].timestamp)
assert list(data_rows[0].keys()) == [
"PredictionID",
"Timestamp",
"Slot",
"Stake",
"Wallet",
"Payout",
"True Value",
"Predicted Value",
]


def test_save_prediction_csv(tmpdir):
predictions = mock_daily_predictions()
key = (
predictions[0].pair.replace("/", "-")
+ predictions[0].timeframe
+ predictions[0].source
)
save_prediction_csv(predictions, str(tmpdir))

with open(os.path.join(str(tmpdir), key + ".csv")) as f:
data = csv.DictReader(f)
data_rows = list(row for row in data)

assert data_rows[0]["Predicted Value"] == str(predictions[0].prediction)
assert data_rows[0]["True Value"] == str(predictions[0].trueval)
assert data_rows[0]["Timestamp"] == str(predictions[0].timestamp)
assert list(data_rows[0].keys()) == [
"Predicted Value",
"True Value",
"Timestamp",
"Stake",
"Payout",
]

0 comments on commit d205aa5

Please sign in to comment.