Integrate Ray tune with the base trainer class #86
Comments
Hello @shrave, Thank you for the kind words. I am happy to hear that this repo is useful for your research. As to the issue, I have never really used
import numpy as np
import torch
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from torchvision import datasets

from pythae.data.datasets import BaseDataset
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainer, BaseTrainerConfig
from pythae.trainers.training_callbacks import TrainingCallback


class RayCallback(TrainingCallback):
    def __init__(self) -> None:
        super().__init__()

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        metrics = kwargs.pop("metrics")  # get the metrics during training
        tune.report(eval_epoch_loss=metrics["eval_epoch_loss"])  # add the metric to monitor in the report


def train_ray(config):
    mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

    # First 1000 images for training, last 1000 for evaluation, scaled to [0, 1]
    train_dataset = BaseDataset(mnist_trainset.data[:1000].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))
    eval_dataset = BaseDataset(mnist_trainset.data[-1000:].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))

    my_training_config = BaseTrainerConfig(
        output_dir='my_model',
        num_epochs=50,
        learning_rate=config["lr"],  # pass the lr for hp search
        per_device_train_batch_size=200,
        per_device_eval_batch_size=200,
        steps_saving=None,
        optimizer_cls="AdamW",
        optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
        scheduler_cls="ReduceLROnPlateau",
        scheduler_params={"patience": 5, "factor": 0.5}
    )

    my_vae_config = VAEConfig(
        input_dim=(1, 28, 28),
        latent_dim=10
    )

    my_vae_model = VAE(
        model_config=my_vae_config
    )

    # Add the ray callback to the callback list
    callbacks = [RayCallback()]

    trainer = BaseTrainer(
        my_vae_model,
        train_dataset,
        eval_dataset,
        my_training_config,
        callbacks=callbacks  # pass the callbacks to the trainer
    )

    trainer.train()  # launch the training


# Sample the learning rate as 10 ** (-10 * u) with u uniform in [0, 1)
search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
}

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="eval_epoch_loss", mode="min"),
    ),
    param_space=search_space,
)

results = tuner.fit()

I have opened #87 since some minor changes should be added to the current implementation of the

Do not hesitate to ask if you have any questions.

Best,
Clément
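A note on the search space in the snippet above: `10 ** (-10 * np.random.rand())` draws the exponent uniformly, so the learning rate is log-uniform between roughly 1e-10 and 1. The sketch below reproduces that draw with the standard library only, to show that each decade gets about the same share of samples. The helper name `sample_log_uniform_lr` and its `min_exponent` parameter are made up for illustration and are not part of pythae or Ray Tune.

```python
import math
import random


def sample_log_uniform_lr(rng: random.Random, min_exponent: float = -10.0) -> float:
    """Draw a learning rate whose base-10 exponent is uniform in (min_exponent, 0]."""
    # rng.random() is uniform in [0, 1), so the exponent lies in (min_exponent, 0].
    return 10 ** (min_exponent * rng.random())


rng = random.Random(0)  # fixed seed so the draw is repeatable
samples = [sample_log_uniform_lr(rng) for _ in range(10_000)]

# Count how many samples land in each decade between 1e-10 and 1.
decade_counts = [0] * 10
for lr in samples:
    # Index 0 covers (1e-1, 1], index 9 covers (1e-10, 1e-9].
    decade = min(int(-math.log10(lr)), 9)
    decade_counts[decade] += 1

print(min(samples), max(samples))
print(decade_counts)  # each entry should be close to 1000
```

If that range is wider than needed, raising `min_exponent` (for example to -5) narrows it. Ray Tune also provides `tune.loguniform(lower, upper)`, which may be a more direct way to express the same kind of search space.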
Hi,
I was wondering whether I could use the Ray Tune hyper-parameter search library, either as a callback or inside the base trainer class, to search for good hyper-parameters for a model and to stop runs early.
Could you please tell me how to integrate it so that training stops midway when a particular hyper-parameter configuration performs poorly?
Even if you could suggest a way to just return the logger at every epoch when a training pipeline instance is called, then my job would be done.
This library has been extremely useful in my research. Thank you very much!
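The request above (metrics handed back at every epoch, plus a way to cut a weak run short) can be pictured as a hook that the training loop calls at the end of each epoch. Below is a minimal, library-free sketch of that pattern. Every name in it (`run_training`, `PatienceStopper`, `record`, the fake loss values) is hypothetical and only illustrates the idea; it is not pythae's or Ray Tune's actual API.

```python
from typing import Callable, Dict, List

# All names below are hypothetical; none of this is pythae or Ray Tune API.
EpochHook = Callable[[int, Dict[str, float]], bool]


def run_training(num_epochs: int, losses: List[float], hooks: List[EpochHook]) -> int:
    """Run a toy loop and return the number of epochs actually completed."""
    for epoch in range(num_epochs):
        metrics = {"eval_epoch_loss": losses[epoch]}  # stand-in for a real evaluation
        # Every hook sees the metrics; any hook may ask the loop to stop.
        stop_requests = [hook(epoch, metrics) for hook in hooks]
        if any(stop_requests):
            return epoch + 1
    return num_epochs


class PatienceStopper:
    """Request a stop once the loss has not improved for `patience` epochs in a row."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def __call__(self, epoch: int, metrics: Dict[str, float]) -> bool:
        loss = metrics["eval_epoch_loss"]
        if loss < self.best:
            self.best, self.bad_epochs = loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


history: List[Dict[str, float]] = []


def record(epoch: int, metrics: Dict[str, float]) -> bool:
    """Keep the per-epoch metrics; never asks for a stop."""
    history.append(metrics)
    return False


fake_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.6]
epochs_run = run_training(len(fake_losses), fake_losses, [record, PatienceStopper(patience=3)])
print(epochs_run, len(history))
```

With these made-up losses the run stops after the third epoch without improvement, and `history` holds one metrics dictionary per completed epoch. A hyper-parameter search library would typically play the role of the stopper here, deciding from the reported metric whether a trial is worth continuing.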