Some options! #78

Open · wants to merge 38 commits into base: main
Changes from all commits · 38 commits
b070a57
simplify the script
cloner174 May 21, 2024
c4aa75f
new Options
cloner174 May 21, 2024
3ae291a
Update data_loader.py
cloner174 May 21, 2024
0ec7b4f
Update exp_long_term_forecasting.py
cloner174 May 21, 2024
ff1d90a
Update data_factory.py
cloner174 May 21, 2024
e08ebc6
Update data_loader.py
cloner174 May 21, 2024
066c872
Update data_loader.py
cloner174 May 21, 2024
854d97c
Update data_loader.py
cloner174 May 21, 2024
98e62b9
add options data_loader.py
cloner174 May 22, 2024
d07c027
add options exp_long_term_forecasting.py
cloner174 May 22, 2024
0c82557
Update LICENSE
cloner174 May 22, 2024
5b57bec
Update LICENSE
cloner174 May 22, 2024
ce974db
Update LICENSE
cloner174 May 22, 2024
1370d35
add options run.py
cloner174 May 22, 2024
b2bbe59
fix data_loader.py
cloner174 May 22, 2024
fd85cbd
save the trues for predict function exp_long_term_forecasting.py
cloner174 May 23, 2024
f618b60
Update exp_long_term_forecasting.py
cloner174 May 23, 2024
bf34f06
Update exp_long_term_forecasting.py
cloner174 May 23, 2024
1b26f1d
fix
cloner174 May 23, 2024
40fe7f7
add empty folder for results during train test val
cloner174 May 23, 2024
45588e8
modify
cloner174 May 23, 2024
8c3b27a
add empty folder for train test val results
cloner174 May 23, 2024
73e22ff
modify
cloner174 May 23, 2024
56ec652
fix
cloner174 May 23, 2024
0c30094
add empty input and 2 root for data
cloner174 May 23, 2024
f97121a
add other loss functions for Model
cloner174 May 23, 2024
9ae3c03
add loss functions access with arg.criter
cloner174 May 23, 2024
6c450f4
add validate info
cloner174 May 24, 2024
60abf1a
predict like a pro
cloner174 May 24, 2024
3db2043
add save_args function
cloner174 May 24, 2024
9fcf3d3
add save_args function
cloner174 May 24, 2024
2f1f766
add save_args class
cloner174 May 24, 2024
ed5c051
dynamic save args
cloner174 May 24, 2024
d38989d
modified: .gitignore
cloner174 May 24, 2024
995832b
fixed
cloner174 Jun 9, 2024
6d95bb3
new updates
cloner174 Jun 11, 2024
ec7ebef
added prediction function
cloner174 Jun 11, 2024
8cbed3d
add prediction function after_train.py
cloner174 Jun 11, 2024
11 changes: 9 additions & 2 deletions .gitignore
@@ -3,9 +3,16 @@ __pycache__/
*.py[cod]
*$py.class

test_results/test_iTransformer_custom_MS_ft15_sl1_ll3_pl512_dm8_nh8_el8_dl1024_df1_fctimeF_ebTrue_dttest_projection_0

results/test_iTransformer_custom_MS_ft15_sl1_ll3_pl512_dm8_nh8_el8_dl1024_df1_fctimeF_ebTrue_dttest_projection_0
result_long_term_forecast.txt
# C extensions
*.so

input/test/data.csv
input/test/scaler.pkl
input/train/data.csv
input/train/scaler.pkl
*/.DS_Store

# Distribution / packaging
@@ -128,4 +135,4 @@ venv.bak/
dmypy.json

# Pyre type checker
.pyre/
.pyre/
1 change: 1 addition & 0 deletions LICENSE
@@ -1,6 +1,7 @@
MIT License

Copyright (c) 2022 THUML @ Tsinghua University
Copyright (c) 2024 cloner174 @ Hamed Hajipour

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Empty file added checkpoints/DONOTREMOVE
Empty file.
65 changes: 42 additions & 23 deletions data_provider/data_factory.py
@@ -16,39 +16,58 @@
 def data_provider(args, flag):
     Data = data_dict[args.data]
     timeenc = 0 if args.embed != 'timeF' else 1
 
-    if flag == 'test':
-        shuffle_flag = False
-        drop_last = True
-        batch_size = 1  # bsz=1 for evaluation
-        freq = args.freq
-    elif flag == 'pred':
+    if flag == 'pred':
         shuffle_flag = False
         drop_last = False
         batch_size = 1
         freq = args.freq
         Data = Dataset_Pred
+        data_set = Data(
+            root_path=args.pred_root_path,
+            data_path=args.pred_data_path,
+            flag=flag,
+            size=[args.seq_len, args.label_len, args.pred_len],
+            features=args.features,
+            target=args.target,
+            timeenc=timeenc,
+            freq=freq,
+            kind_of_scaler=args.kind_of_scaler if hasattr(args, 'kind_of_scaler') else 'standard',
+            name_of_col_with_date=args.name_of_col_with_date if hasattr(args, 'name_of_col_with_date') else 'date',
+            scale=args.scale if hasattr(args, 'scale') else True,
+            max_use_of_row=args.max_use_of_row if hasattr(args, 'max_use_of_row') else 'No Lim',
+        )
+        print(flag, len(data_set))
     else:
-        shuffle_flag = True
-        drop_last = True
-        batch_size = args.batch_size  # bsz for train and valid
-        freq = args.freq
-
-    data_set = Data(
-        root_path=args.root_path,
-        data_path=args.data_path,
-        flag=flag,
-        size=[args.seq_len, args.label_len, args.pred_len],
-        features=args.features,
-        target=args.target,
-        timeenc=timeenc,
-        freq=freq,
-    )
-    print(flag, len(data_set))
+        if flag == 'test':
+            shuffle_flag = False
+            drop_last = True
+            batch_size = 1  # bsz=1 for evaluation
+            freq = args.freq
+        else:
+            shuffle_flag = True
+            drop_last = True
+            batch_size = args.batch_size  # bsz for train and valid
+            freq = args.freq
+        data_set = Data(
+            root_path=args.root_path,
+            data_path=args.data_path,
+            flag=flag,
+            size=[args.seq_len, args.label_len, args.pred_len],
+            features=args.features,
+            target=args.target,
+            timeenc=timeenc,
+            freq=freq,
+            test_size=args.test_size if hasattr(args, 'test_size') else 0.2,
+            kind_of_scaler=args.kind_of_scaler if hasattr(args, 'kind_of_scaler') else 'standard',
+            name_of_col_with_date=args.name_of_col_with_date if hasattr(args, 'name_of_col_with_date') else 'date',
+            scale=args.scale if hasattr(args, 'scale') else True,
+        )
+        print(flag, len(data_set))
     data_loader = DataLoader(
         data_set,
         batch_size=batch_size,
         shuffle=shuffle_flag,
         num_workers=args.num_workers,
         drop_last=drop_last)
 
     return data_set, data_loader
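
A minimal, illustrative sketch (not part of the PR) of how the optional attributes read via hasattr() above might be supplied when calling data_provider(). The attribute names follow the diff; the concrete values, paths, and sequence lengths are assumptions, and the fallbacks ('standard', 'date', 0.2, True, 'No Lim') are the defaults visible in the diff. In the PR itself these would normally come from the argparse options added to run.py; a plain Namespace is used here only to keep the example self-contained.

# Illustrative only: values and file paths are assumptions, not taken from the PR.
from argparse import Namespace

from data_provider.data_factory import data_provider

args = Namespace(
    # standard iTransformer-style options
    data='custom', embed='timeF', freq='h', batch_size=32, num_workers=0,
    root_path='./input/train/', data_path='data.csv',
    seq_len=96, label_len=48, pred_len=96,
    features='MS', target='OT',
    # options this PR reads via hasattr(); omitting any of them falls back
    # to the defaults shown in the diff
    test_size=0.2,                    # train/test split ratio
    kind_of_scaler='standard',        # scaler choice handled by the new data_loader.py
    name_of_col_with_date='date',     # name of the timestamp column
    scale=True,                       # whether to scale the series
    # read only when flag == 'pred'
    pred_root_path='./input/test/',
    pred_data_path='data.csv',
    max_use_of_row='No Lim',          # row-count cap for the prediction data
)

train_set, train_loader = data_provider(args, flag='train')
pred_set, pred_loader = data_provider(args, flag='pred')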