Skip to content

Commit

Permalink
1.Add script to run hpo; 2. update version to 0.0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
alilevy committed Jul 12, 2023
1 parent 132b6a7 commit d80be97
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 1 deletion.
55 changes: 55 additions & 0 deletions examples/configs/hpo_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
pipeline_config_id: hpo_runner_config

data:
taxi:
data_format: pkl
train_dir: ./data/taxi/train.pkl
valid_dir: ./data/taxi/dev.pkl
test_dir: ./data/taxi/test.pkl
data_specs:
num_event_types: 10
pad_token_id: 10
padding_side: right
truncation_side: right

hpo:
storage_uri: sqlite://hpo_test.db
is_continuous: False
framework_id: optuna # the framework of hpo
n_trials: 10


NHP_train:
base_config:
stage: train
backend: torch
dataset_id: taxi
runner_id: std_tpp
model_id: NHP # model name
base_dir: './checkpoints/'
trainer_config:
batch_size: 256
max_epoch: 200
shuffle: False
optimizer: adam
learning_rate: 1.e-3
valid_freq: 1
use_tfb: False
metrics: [ 'acc', 'rmse' ]
seed: 2019
gpu: -1
model_config:
hidden_size: 64
loss_integral_num_sample_per_step: 20
# pretrained_model_dir: ./checkpoints/75518_4377527680_230530-132355/models/saved_model
thinning:
num_seq: 10
num_sample: 1
num_exp: 500 # number of i.i.d. Exp(intensity_bound) draws at one time in thinning algorithm
look_ahead_time: 10
patience_counter: 5 # the maximum iteration used in adaptive thinning
over_sample_rate: 5
num_samples_boundary: 5
dtime_max: 5
num_step_gen: 1

26 changes: 26 additions & 0 deletions examples/train_nhp_hpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import argparse

from easy_tpp.config_factory import Config
from easy_tpp.hpo import HyperTuner


def main():
parser = argparse.ArgumentParser()

parser.add_argument('--config_dir', type=str, required=False, default='configs/hpo_config.yaml',
help='Dir of configuration yaml to train and evaluate the model.')

parser.add_argument('--experiment_id', type=str, required=False, default='NHP_train',
help='Experiment id in the config file.')

args = parser.parse_args()

config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)

tuner = HyperTuner.build_from_config(config)

tuner.run()


if __name__ == '__main__':
main()
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.4'
__version__ = '0.0.5'

0 comments on commit d80be97

Please sign in to comment.