From c9b95fa145a1acdfa209ae4913939ab614adc6e5 Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Wed, 31 Aug 2022 22:46:32 -0700 Subject: [PATCH] replaced torch.no_grad() with torch.inference_mode() --- examples/hf_transformers/text_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/hf_transformers/text_classification.py b/examples/hf_transformers/text_classification.py index 9416cef9..ea357b1a 100644 --- a/examples/hf_transformers/text_classification.py +++ b/examples/hf_transformers/text_classification.py @@ -124,7 +124,7 @@ def train_one_epoch(training_box, epoch, log_freq): metric_logger.meters['sample/s'].update(batch_size / (time.time() - start_time)) -@torch.no_grad() +@torch.inference_mode() def evaluate(model, data_loader, metric, is_regression, accelerator, title=None, header='Test: '): if title is not None: logger.info(title) @@ -170,7 +170,7 @@ def train(teacher_model, student_model, dataset_dict, is_regression, ckpt_dir_pa training_box.post_process() -@torch.no_grad() +@torch.inference_mode() def predict_private(model, dataset_dict, label_names_dict, is_regression, accelerator, private_configs, private_output_dir_path): logger.info('Start prediction for private dataset(s)')