Skip to content

Commit

Permalink
improve pretrain
Browse files Browse the repository at this point in the history
  • Loading branch information
Dong Zhou authored and you-n-g committed Jul 30, 2021
1 parent 5b7b48e commit a7c41b6
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ task:
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch: 100
Expand All @@ -86,7 +87,7 @@ task:
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: False
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: TPE
src_info: LR_TPE

model_config: &model_config
input_size: 158
hidden_size: 256
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.2

Expand Down Expand Up @@ -66,6 +67,7 @@ task:
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch: 100
Expand All @@ -79,7 +81,7 @@ task:
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: False
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
Expand Down
3 changes: 2 additions & 1 deletion examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ task:
kwargs:
tra_config: *tra_config
model_config: *model_config
model_type: RNN
lr: 1e-3
n_epochs: 100
max_steps_per_epoch: 100
Expand All @@ -80,7 +81,7 @@ task:
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: False
pretrain: True
init_state:
freeze_model: False
freeze_predictors: False
Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/model/pytorch_tra.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def fit(self, dataset, evals_result=dict()):
self.logger.info("reset TRA")
self.tra.reset_parameters() # reset both router and predictors

self.optimizer = optim.Adam(self.tra.parameters(), lr=self.lr) # optimize TRA only
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)

self.logger.info("training...")
best_score, _ = self._fit(train_set, valid_set, test_set, evals_result, start_epoch=epoch, is_pretrain=False)
Expand Down

0 comments on commit a7c41b6

Please sign in to comment.