-
Notifications
You must be signed in to change notification settings - Fork 204
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feature/anomaly_detection_v2
- Loading branch information
Showing
8 changed files
with
344 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
download_data: | ||
mkdir -p data | ||
curl https://www.datasource.ai/attachments/eyJpZCI6Ijk4NDYxNjE2NmZmZjM0MGRmNmE4MTczOGMyMzI2ZWI2LmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiUGhhc2UgMCAtIFNhbGVzLmNzdiIsInNpemUiOjEwODA0NjU0LCJtaW1lX3R5cGUiOiJ0ZXh0L2NzdiJ9fQ -o data/phase_0_sales.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6ImM2OGQxNGNmNTJkZDQ1MTUyZTg0M2FkMDAyMjVlN2NlLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiUGhhc2UgMSAtIFNhbGVzLmNzdiIsInNpemUiOjEwMTgzOTYsIm1pbWVfdHlwZSI6InRleHQvY3N2In19 -o data/phase_1_sales.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6IjhlNmJmNmU3ZTlhNWQ4NTcyNGVhNTI4YjAwNTk3OWE1LmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiUGhhc2UgMiAtIFNhbGVzLmNzdiIsInNpemUiOjEwMTI0MzcsIm1pbWVfdHlwZSI6InRleHQvY3N2In19 -o data/phase_2_sales.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6IjI1NDQxYmMyMTQ3MTA0MjJhMDcyYjllODcwZjEyNmY4LmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoicGhhc2UgMiBzdWJtaXNzaW9uIGV4YW1pbmUgc21vb3RoZWQgMjAyNDEwMTcgRklOQUwuY3N2Iiwic2l6ZSI6MTk5MzAzNCwibWltZV90eXBlIjoidGV4dC9jc3YifX0 -o data/solution_1st_place.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6IjU3ODhjZTUwYTU3MTg3NjFlYzMzOWU0ZTg3MWUzNjQxLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoidm4xX3N1Ym1pc3Npb25fanVzdGluX2Z1cmxvdHRlLmNzdiIsInNpemUiOjM5MDkzNzksIm1pbWVfdHlwZSI6InRleHQvY3N2In19 -o data/solution_2nd_place.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6ImE5NzcwNTZhMzhhMTc2ZWJjODFkMDMwMTM2Y2U2MTdlLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiYXJzYW5pa3phZF9zdWIuY3N2Iiwic2l6ZSI6Mzg4OTcyNCwibWltZV90eXBlIjoidGV4dC9jc3YifX0 -o data/solution_3rd_place.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6ImVlZmUxYWY2NDFjOWMwM2IxMzRhZTc2MzI1Nzg3NzIxLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiVEZUX3R1bmVkX1YyX3NlZWRfNDIuY3N2Iiwic2l6ZSI6NjA3NDgzLCJtaW1lX3R5cGUiOiJ0ZXh0L2NzdiJ9fQ -o data/solution_4th_place.csv | ||
curl https://www.datasource.ai/attachments/eyJpZCI6IjMwMDEwMmY3NTNhMzlhN2YxNTk3ODYxZTI1N2Q2NzRmLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiZGl2aW5lb3B0aW1pemVkd2VpZ2h0c2Vuc2VtYmxlLmNzdiIsInNpemUiOjE3OTU0NzgsIm1pbWVfdHlwZSI6InRleHQvY3N2In19 -o data/solution_5th_place.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Score 2nd place in the VN1 Challenge with a few lines of code in under 10 seconds using TimeGPT | ||
|
||
We present a fully reproducible experiment demonstrating that Nixtla's **TimeGPT** can achieve the **2nd position** in the [VN1 Forecasting Accuracy Challenge](https://www.datasource.ai/en/home/data-science-competitions-for-startups/phase-2-vn1-forecasting-accuracy-challenge/description) with **zero-shot forecasting**. This result was achieved using the zero-shot capabilities of the foundation model, as most of the code focuses on **data cleaning and preprocessing**, not model training or parameter tuning. | ||
|
||
The table below showcases the official competition results, with TimeGPT outperforming the 2nd, 3rd, and other models in the competition. | ||
|
||
| **Model** | **Score** | | ||
| ----------- | ---------- | | ||
| 1st | 0.4637 | | ||
| **TimeGPT** | **0.4651** | | ||
| 2nd | 0.4657 | | ||
| 3rd | 0.4758 | | ||
| 4th | 0.4774 | | ||
| 5th | 0.4808 | | ||
|
||
--- | ||
|
||
## **Introduction** | ||
|
||
The VN1 Forecasting Accuracy Challenge tasked participants with forecasting future sales using historical sales and pricing data. The goal was to develop robust predictive models capable of anticipating sales trends for various products across different clients and warehouses. Submissions were evaluated based on their accuracy and bias against actual sales figures. | ||
|
||
The competition was structured into two phases: | ||
|
||
- **Phase 1** (September 12 - October 3, 2024): Participants used the provided Phase 0 sales data to predict sales for Phase 1. This phase lasted three weeks and featured live leaderboard updates to track participant progress. | ||
- **Phase 2** (October 3 - October 17, 2024): Participants utilized both Phase 0 and Phase 1 data to predict sales for Phase 2. Unlike Phase 1, there were no leaderboard updates during this phase until the competition concluded. | ||
|
||
One of the competition's key requirements was to use **open-source solutions**. However, as TimeGPT works through an API, we did not upload the forecasts generated during the competition. Instead, we showcase the effectiveness of TimeGPT by presenting the results of our approach. | ||
|
||
Our approach leverages the power of **zero-shot forecasting**, where no training, fine-tuning, or manual hyperparameter adjustments are needed. We used only **historical sales data** without any exogenous variables to generate forecasts. With this setting, TimeGPT provides forecasts that achieve an accuracy surpassing nearly all competitors. | ||
|
||
Remarkably, the process required only **5 seconds of inference time**, demonstrating the efficiency of TimeGPT. | ||
|
||
--- | ||
|
||
## **Empirical Evaluation** | ||
|
||
This study considers time series from multiple datasets provided during the competition. Unlike most competitors, we do not train, fine-tune, or manually adjust TimeGPT. Instead, we rely on **zero-shot learning** to forecast the time series directly. | ||
|
||
This study contrasts TimeGPT's zero-shot forecasts against the top 1st, 2nd, and 3rd models submitted to the competition. Our evaluation method follows the official rules and metrics of the VN1 competition. | ||
|
||
An R version of this study is also available via `nixtlar`, a CRAN package that provides an interface to Nixtla's TimeGPT. | ||
--- | ||
|
||
## **Results** | ||
|
||
The table below summarizes the official competition results. Despite using a zero-shot approach, **TimeGPT achieves the 2nd position** with a score of **0.4651**, outperforming the models ranked 2nd and 3rd. | ||
|
||
| **Model** | **Score** | | ||
| ----------- | ---------- | | ||
| 1st | 0.4637 | | ||
| **TimeGPT** | **0.4651** | | ||
| 2nd | 0.4657 | | ||
| 3rd | 0.4758 | | ||
| 4th | 0.4774 | | ||
| 5th | 0.4808 | | ||
|
||
--- | ||
|
||
## **Reproducibility** | ||
|
||
All necessary code and detailed instructions for reproducing the experiment are available in this repository. | ||
|
||
### **Instructions** | ||
|
||
1. **Get an API Key** from the [Nixtla Dashboard](https://dashboard.nixtla.io/). Copy it and paste it into the `.env.example` file. Rename the file to `.env`. | ||
|
||
2. **Set up [uv](https://github.com/astral-sh/uv):** | ||
|
||
```bash | ||
pip install uv | ||
uv venv --python 3.10 | ||
source .venv/bin/activate | ||
uv pip sync requirements.txt | ||
``` | ||
|
||
3. **Download data:** | ||
|
||
```bash | ||
make download_data | ||
``` | ||
|
||
4. **Run the complete pipeline:** | ||
|
||
```bash | ||
python -m src.main | ||
``` | ||
|
||
5. **Tests** | ||
|
||
We made sure that the results are comparable by comparing the results against the [official competition results](https://www.datasource.ai/en/home/data-science-competitions-for-startups/phase-2-vn1-forecasting-accuracy-challenge/leaderboard). You can run the tests using: | ||
|
||
```bash | ||
pytest | ||
``` | ||
|
||
6. **R results:** | ||
For the R version of this study using `nixtlar`, run the `main.R` script. Make sure the `functions.R` script is in the same directory. | ||
--- | ||
|
||
## **References** | ||
|
||
- Vandeput, Nicolas. “VN1 Forecasting - Accuracy Challenge.” DataSource.ai, DataSource, 3 Oct. 2024, [https://www.datasource.ai/en/home/data-science-competitions-for-startups/phase-2-vn1-forecasting-accuracy-challenge/description](https://www.datasource.ai/en/home/data-science-competitions-for-startups/phase-2-vn1-forecasting-accuracy-challenge/description) | ||
- [TimeGPT Paper](https://arxiv.org/abs/2310.03589) | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
|
||
# Functions for VN1 Forecasting Competition ---- | ||
|
||
read_and_prepare_data <- function(dataset){ | ||
# Reads data in wide format and returns it in long format with columns `unique_id`, `ds`, and `y` | ||
url <- get_dataset_url(dataset) | ||
df_wide <- fread(url) | ||
df_wide <- df_wide |> | ||
mutate(unique_id = paste0(Client, "/", Warehouse, "/", Product)) |> | ||
select(c(unique_id, everything())) |> | ||
select(-c(Client, Warehouse, Product)) | ||
|
||
df <- pivot_longer( | ||
data = df_wide, | ||
cols = -unique_id, | ||
names_to = "ds", | ||
values_to = "y" | ||
) | ||
|
||
if(startsWith(dataset, "winners")){ | ||
names(df)[which(names(df) == "y")] <- dataset | ||
} | ||
|
||
return(df) | ||
} | ||
|
||
get_train_data <- function(df0, df1){ | ||
# Merges training data from phase 0 and phase 1 and removes leading zeros | ||
df <- rbind(df0, df1) |> | ||
arrange(unique_id, ds) | ||
|
||
df_clean <- df |> | ||
group_by(unique_id) |> | ||
mutate(cumsum = cumsum(y)) |> | ||
filter(cumsum > 0) |> | ||
select(-cumsum) |> | ||
ungroup() | ||
|
||
return(df_clean) | ||
} | ||
|
||
vn1_competition_evaluation <- function(test, forecast, model){ | ||
# Computes competition evaluation | ||
if(!is.character(forecast$ds)){ | ||
forecast$ds <- as.character(forecast$ds) # nixtlar returns timestamps for plotting | ||
} | ||
|
||
res <- merge(forecast, test, by=c("unique_id", "ds")) | ||
|
||
res <- res |> | ||
mutate(abs_err = abs(res[[model]]-res$y)) |> | ||
mutate(err = res[[model]]-res$y) | ||
|
||
abs_err = sum(res$abs_err, na.rm = TRUE) | ||
err = sum(res$err, na.rm = TRUE) | ||
score = abs_err+abs(err) | ||
score = score/sum(res$y) | ||
score = round(score, 4) | ||
|
||
return(score) | ||
} | ||
|
||
get_dataset_url <- function(dataset){ | ||
# Returns the url of the given competition dataset | ||
urls <- list( | ||
phase0_sales = "https://www.datasource.ai/attachments/eyJpZCI6Ijk4NDYxNjE2NmZmZjM0MGRmNmE4MTczOGMyMzI2ZWI2LmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiUGhhc2UgMCAtIFNhbGVzLmNzdiIsInNpemUiOjEwODA0NjU0LCJtaW1lX3R5cGUiOiJ0ZXh0L2NzdiJ9fQ", | ||
phase1_sales = "https://www.datasource.ai/attachments/eyJpZCI6ImM2OGQxNGNmNTJkZDQ1MTUyZTg0M2FkMDAyMjVlN2NlLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiUGhhc2UgMSAtIFNhbGVzLmNzdiIsInNpemUiOjEwMTgzOTYsIm1pbWVfdHlwZSI6InRleHQvY3N2In19", | ||
phase2_sales = "https://www.datasource.ai/attachments/eyJpZCI6IjhlNmJmNmU3ZTlhNWQ4NTcyNGVhNTI4YjAwNTk3OWE1LmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiUGhhc2UgMiAtIFNhbGVzLmNzdiIsInNpemUiOjEwMTI0MzcsIm1pbWVfdHlwZSI6InRleHQvY3N2In19", | ||
winners1 = "https://www.datasource.ai/attachments/eyJpZCI6IjI1NDQxYmMyMTQ3MTA0MjJhMDcyYjllODcwZjEyNmY4LmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoicGhhc2UgMiBzdWJtaXNzaW9uIGV4YW1pbmUgc21vb3RoZWQgMjAyNDEwMTcgRklOQUwuY3N2Iiwic2l6ZSI6MTk5MzAzNCwibWltZV90eXBlIjoidGV4dC9jc3YifX0", | ||
winners2 = "https://www.datasource.ai/attachments/eyJpZCI6IjU3ODhjZTUwYTU3MTg3NjFlYzMzOWU0ZTg3MWUzNjQxLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoidm4xX3N1Ym1pc3Npb25fanVzdGluX2Z1cmxvdHRlLmNzdiIsInNpemUiOjM5MDkzNzksIm1pbWVfdHlwZSI6InRleHQvY3N2In19", | ||
winners3 = "https://www.datasource.ai/attachments/eyJpZCI6ImE5NzcwNTZhMzhhMTc2ZWJjODFkMDMwMTM2Y2U2MTdlLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiYXJzYW5pa3phZF9zdWIuY3N2Iiwic2l6ZSI6Mzg4OTcyNCwibWltZV90eXBlIjoidGV4dC9jc3YifX0", | ||
winners4 = "https://www.datasource.ai/attachments/eyJpZCI6ImVlZmUxYWY2NDFjOWMwM2IxMzRhZTc2MzI1Nzg3NzIxLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiVEZUX3R1bmVkX1YyX3NlZWRfNDIuY3N2Iiwic2l6ZSI6NjA3NDgzLCJtaW1lX3R5cGUiOiJ0ZXh0L2NzdiJ9fQ", | ||
winners5 = "https://www.datasource.ai/attachments/eyJpZCI6IjMwMDEwMmY3NTNhMzlhN2YxNTk3ODYxZTI1N2Q2NzRmLmNzdiIsInN0b3JhZ2UiOiJzdG9yZSIsIm1ldGFkYXRhIjp7ImZpbGVuYW1lIjoiZGl2aW5lb3B0aW1pemVkd2VpZ2h0c2Vuc2VtYmxlLmNzdiIsInNpemUiOjE3OTU0NzgsIm1pbWVfdHlwZSI6InRleHQvY3N2In19" | ||
) | ||
|
||
return(urls[[dataset]]) | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
|
||
# VN1 Forecasting Competition Solution with nixtlar ---- | ||
|
||
install.packages(c("nixtlar", "tidyverse", "data.table")) | ||
|
||
library(nixtlar) | ||
library(tidyverse) | ||
library(data.table) | ||
|
||
source("functions.R") # same directory as main.R | ||
|
||
## Load Data ---- | ||
sales0 <- read_and_prepare_data("phase0_sales") | ||
sales1 <- read_and_prepare_data("phase1_sales") | ||
test_df <- read_and_prepare_data("phase2_sales") | ||
|
||
## Prepare Training Dataset ---- | ||
train_df <- get_train_data(sales0, sales1) | ||
|
||
## Generate TimeGPT Forecast ---- | ||
|
||
# nixtla_client_setup(api_key = "Your API key here") | ||
# Learn how to set up your API key here: https://nixtla.github.io/nixtlar/articles/setting-up-your-api-key.html | ||
|
||
fc <- nixtla_client_forecast(train_df, h=13, model="timegpt-1-long-horizon") | ||
|
||
## Visualize TimeGPT Forecast ---- | ||
nixtla_client_plot(train_df, fc) | ||
|
||
## Evaluate TimeGPT & Top 5 Competition Solutions ---- | ||
timegpt_score <- vn1_competition_evaluation(test_df, fc, "TimeGPT") | ||
|
||
scores <- lapply(1:5, function(i){ # Top 5 | ||
winner_df <- read_and_prepare_data(paste0("winners", i)) | ||
vn1_competition_evaluation(test_df, winner_df, model = paste0("winners", i)) | ||
}) | ||
|
||
scores_df <- data.frame( | ||
"Result" = c(paste0("Place #", 1:5), "TimeGPT"), | ||
"Score" = c(as.numeric(scores), timegpt_score) | ||
) | ||
|
||
scores_df <- scores_df |> arrange(Score) | ||
print(scores_df) # TimeGPT places 2nd! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from time import time | ||
|
||
|
||
import numpy as np | ||
import pandas as pd | ||
from dotenv import load_dotenv | ||
from nixtla import NixtlaClient | ||
|
||
load_dotenv() | ||
|
||
|
||
def read_and_prepare_data(file_path: str, value_name: str = "y") -> pd.DataFrame: | ||
"""Reads data in wide format, and returns it in long format with columns `unique_id`, `ds`, `y`""" | ||
df = pd.read_csv(file_path) | ||
uid_cols = ["Client", "Warehouse", "Product"] | ||
df["unique_id"] = df[uid_cols].astype(str).agg("-".join, axis=1) | ||
df = df.drop(uid_cols, axis=1) | ||
df = df.melt(id_vars=["unique_id"], var_name="ds", value_name=value_name) | ||
df["ds"] = pd.to_datetime(df["ds"]) | ||
df = df.sort_values(by=["unique_id", "ds"]) | ||
return df | ||
|
||
|
||
def get_train_data() -> pd.DataFrame: | ||
"""Reads all train data and returns it in long format with columns `unique_id`, `ds`, `y`""" | ||
train_list = [read_and_prepare_data(f"./data/phase_{i}_sales.csv") for i in [0, 1]] | ||
train_df = pd.concat(train_list).reset_index(drop=True) | ||
train_df = train_df.sort_values(by=["unique_id", "ds"]) | ||
|
||
def remove_leading_zeros(group): | ||
first_non_zero_index = group["y"].ne(0).idxmax() | ||
return group.loc[first_non_zero_index:] | ||
|
||
train_df = ( | ||
train_df.groupby("unique_id").apply(remove_leading_zeros).reset_index(drop=True) | ||
) | ||
return train_df | ||
|
||
|
||
def get_competition_forecasts() -> pd.DataFrame: | ||
"""Reads all competition forecasts and returns it in long format with columns `unique_id`, `ds`, `y`""" | ||
fcst_df: pd.DataFrame | None = None | ||
for place in ["1st", "2nd", "3rd", "4th", "5th"]: | ||
fcst_df_place = read_and_prepare_data( | ||
f"./data/solution_{place}_place.csv", place | ||
) | ||
if fcst_df is None: | ||
fcst_df = fcst_df_place | ||
else: | ||
fcst_df = fcst_df.merge( | ||
fcst_df_place, | ||
on=["unique_id", "ds"], | ||
how="left", | ||
) | ||
return fcst_df | ||
|
||
|
||
def vn1_competition_evaluation(forecasts: pd.DataFrame) -> pd.DataFrame: | ||
"""Computes competition evaluation scores""" | ||
actual = read_and_prepare_data("./data/phase_2_sales.csv") | ||
res = actual[["unique_id", "ds", "y"]].merge( | ||
forecasts, on=["unique_id", "ds"], how="left" | ||
) | ||
ids_forecasts = forecasts["unique_id"].unique() | ||
ids_res = res["unique_id"].unique() | ||
assert set(ids_forecasts) == set(ids_res), "Some unique_ids are missing" | ||
scores = {} | ||
for model in [col for col in forecasts.columns if col not in ["unique_id", "ds"]]: | ||
abs_err = np.nansum(np.abs(res[model] - res["y"])) | ||
err = np.nansum(res[model] - res["y"]) | ||
score = abs_err + abs(err) | ||
score = score / res["y"].sum() | ||
scores[model] = round(score, 4) | ||
score_df = pd.DataFrame(list(scores.items()), columns=["model", "score"]) | ||
score_df = score_df.sort_values(by="score") | ||
return score_df | ||
|
||
|
||
def main(): | ||
"""Complete pipeline""" | ||
train_df = get_train_data() | ||
client = NixtlaClient() | ||
init = time() | ||
fcst_df = client.forecast(train_df, h=13, model="timegpt-1-long-horizon") | ||
print(f"TimeGPT time: {time() - init}") | ||
fcst_df_comp = get_competition_forecasts() | ||
fcst_df = fcst_df.merge(fcst_df_comp, on=["unique_id", "ds"], how="left") | ||
eval_df = vn1_competition_evaluation(fcst_df) | ||
print(eval_df) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import pandas as pd | ||
|
||
from src.main import vn1_competition_evaluation, get_competition_forecasts | ||
|
||
|
||
def test_vn1_competition_evaluation(): | ||
forecasts = get_competition_forecasts() | ||
eval_df = vn1_competition_evaluation(forecasts) | ||
assert len(eval_df) == 5 | ||
pd.testing.assert_series_equal( | ||
eval_df["score"], | ||
pd.Series([0.4637, 0.4657, 0.4758, 0.4774, 0.4808]), | ||
check_names=False, | ||
) |