diff --git a/AIAgent/run_training.py b/AIAgent/run_training.py index d0992e2..39173f0 100644 --- a/AIAgent/run_training.py +++ b/AIAgent/run_training.py @@ -129,35 +129,30 @@ def model_init(**model_params) -> nn.Module: epochs=training_config.epochs, val_config=validation_config, ) - try: - if optuna_config.study_uri is None and weights_uri is None: - - def save_study(study, _): - joblib.dump( - study, - CURRENT_STUDY_PATH, - ) - with mlflow.start_run(mlflow.last_active_run().info.run_id): - mlflow.log_artifact(CURRENT_STUDY_PATH) - - study = optuna.create_study( - sampler=sampler, direction=optuna_config.study_direction - ) - study.optimize( - objective_partial, - n_trials=optuna_config.n_trials, - gc_after_trial=True, - n_jobs=optuna_config.n_jobs, - callbacks=[save_study], - ) - else: - downloaded_artifact_path = mlflow.artifacts.download_artifacts( - optuna_config.study_uri, dst_path=str(REPORT_PATH) + if optuna_config.study_uri is None and weights_uri is None: + def save_study(study, _): + joblib.dump( + study, + CURRENT_STUDY_PATH, ) - study: optuna.Study = joblib.load(downloaded_artifact_path) - objective_partial(study.best_trial) - except RuntimeError: - logging.error("Fail to train") + with mlflow.start_run(mlflow.last_active_run().info.run_id): + mlflow.log_artifact(CURRENT_STUDY_PATH) + study = optuna.create_study( + sampler=sampler, direction=optuna_config.study_direction + ) + study.optimize( + objective_partial, + n_trials=optuna_config.n_trials, + gc_after_trial=True, + n_jobs=optuna_config.n_jobs, + callbacks=[save_study], + ) + else: + downloaded_artifact_path = mlflow.artifacts.download_artifacts( + optuna_config.study_uri, dst_path=str(REPORT_PATH) + ) + study: optuna.Study = joblib.load(downloaded_artifact_path) + objective_partial(study.best_trial) def objective(