From 8795dc2d4e1d64457fadd46523a3f452c6a10cd7 Mon Sep 17 00:00:00 2001 From: Vinicius Reis Date: Mon, 2 Mar 2020 09:56:35 -0800 Subject: [PATCH 1/2] Remove local_variables from on_step Summary: local_variables makes the code in train_step really hard to read. Killing it from all hooks will take time, so start from a single hook (on_step). Differential Revision: D20171981 fbshipit-source-id: 79342642cbac9a8ebcc9ca59a2b7cce8d4f64f14 --- classy_vision/hooks/classy_hook.py | 6 ++--- .../exponential_moving_average_model_hook.py | 2 +- .../hooks/loss_lr_meter_logging_hook.py | 20 +++++--------- classy_vision/hooks/progress_bar_hook.py | 4 +-- classy_vision/hooks/tensorboard_plot_hook.py | 4 +-- classy_vision/hooks/time_metrics_hook.py | 20 ++++++-------- classy_vision/tasks/classification_task.py | 27 ++++++++++++++++++- classy_vision/tasks/classy_task.py | 3 ++- ...onential_moving_average_model_hook_test.py | 2 +- test/hooks_loss_lr_meter_logging_hook_test.py | 16 +++++------ test/hooks_time_metrics_hook_test.py | 12 ++++----- test/manual/hooks_progress_bar_hook_test.py | 8 +++--- .../hooks_tensorboard_plot_hook_test.py | 4 +-- test/optim_param_scheduler_test.py | 2 +- 14 files changed, 70 insertions(+), 60 deletions(-) diff --git a/classy_vision/hooks/classy_hook.py b/classy_vision/hooks/classy_hook.py index 10fca71069..6ee11a5e0c 100644 --- a/classy_vision/hooks/classy_hook.py +++ b/classy_vision/hooks/classy_hook.py @@ -51,7 +51,7 @@ def __init__(self, a, b): def __init__(self): self.state = ClassyHookState() - def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None: + def _noop(self, *args, **kwargs) -> None: """Derived classes can set their hook functions to this. This is useful if they want those hook functions to not do anything. @@ -79,9 +79,7 @@ def on_phase_start( pass @abstractmethod - def on_step( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def on_step(self, task: "tasks.ClassyTask") -> None: """Called each time after parameters have been updated by the optimizer.""" pass diff --git a/classy_vision/hooks/exponential_moving_average_model_hook.py b/classy_vision/hooks/exponential_moving_average_model_hook.py index d0c72a35ae..409cf0e910 100644 --- a/classy_vision/hooks/exponential_moving_average_model_hook.py +++ b/classy_vision/hooks/exponential_moving_average_model_hook.py @@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> Non # state in the test phase self._save_current_model_state(task.base_model, self.state.model_state) - def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None: + def on_step(self, task: ClassyTask) -> None: if not task.train: return diff --git a/classy_vision/hooks/loss_lr_meter_logging_hook.py b/classy_vision/hooks/loss_lr_meter_logging_hook.py index 33cb283024..9de42d55fa 100644 --- a/classy_vision/hooks/loss_lr_meter_logging_hook.py +++ b/classy_vision/hooks/loss_lr_meter_logging_hook.py @@ -45,13 +45,11 @@ def on_phase_end( # trainer to implement an unsynced end of phase meter or # for meters to not provide a sync function. logging.info("End of phase metric values:") - self._log_loss_meters(task, local_variables) + self._log_loss_meters(task) if task.train: - self._log_lr(task, local_variables) + self._log_lr(task) - def on_step( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def on_step(self, task: "tasks.ClassyTask") -> None: """ Log the LR every log_freq batches, if log_freq is not None. """ @@ -59,22 +57,18 @@ def on_step( return batches = len(task.losses) if batches and batches % self.log_freq == 0: - self._log_lr(task, local_variables) + self._log_lr(task) logging.info("Local unsynced metric values:") - self._log_loss_meters(task, local_variables) + self._log_loss_meters(task) - def _log_lr( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def _log_lr(self, task: "tasks.ClassyTask") -> None: """ Compute and log the optimizer LR. """ optimizer_lr = task.optimizer.parameters.lr logging.info("Learning Rate: {}\n".format(optimizer_lr)) - def _log_loss_meters( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def _log_loss_meters(self, task: "tasks.ClassyTask") -> None: """ Compute and log the loss and meters. """ diff --git a/classy_vision/hooks/progress_bar_hook.py b/classy_vision/hooks/progress_bar_hook.py index bdecbc6c9f..fc5dfde405 100644 --- a/classy_vision/hooks/progress_bar_hook.py +++ b/classy_vision/hooks/progress_bar_hook.py @@ -51,9 +51,7 @@ def on_phase_start( self.progress_bar = progressbar.ProgressBar(self.bar_size) self.progress_bar.start() - def on_step( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def on_step(self, task: "tasks.ClassyTask") -> None: """Update the progress bar with the batch size.""" if task.train and is_master() and self.progress_bar is not None: self.batches += 1 diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py index e86864e313..6abeec04ba 100644 --- a/classy_vision/hooks/tensorboard_plot_hook.py +++ b/classy_vision/hooks/tensorboard_plot_hook.py @@ -64,9 +64,7 @@ def on_phase_start( self.wall_times = [] self.num_steps_global = [] - def on_step( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def on_step(self, task: "tasks.ClassyTask") -> None: """Store the observed learning rates.""" if self.learning_rates is None: logging.warning("learning_rates is not initialized") diff --git a/classy_vision/hooks/time_metrics_hook.py b/classy_vision/hooks/time_metrics_hook.py index 618e3fbd35..475282a699 100644 --- a/classy_vision/hooks/time_metrics_hook.py +++ b/classy_vision/hooks/time_metrics_hook.py @@ -40,11 +40,9 @@ def on_phase_start( Initialize start time and reset perf stats """ self.start_time = time.time() - local_variables["perf_stats"] = PerfStats() + task.perf_stats = PerfStats() - def on_step( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def on_step(self, task: "tasks.ClassyTask") -> None: """ Log metrics every log_freq batches, if log_freq is not None. """ @@ -52,7 +50,7 @@ def on_step( return batches = len(task.losses) if batches and batches % self.log_freq == 0: - self._log_performance_metrics(task, local_variables) + self._log_performance_metrics(task) def on_phase_end( self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] @@ -62,11 +60,9 @@ def on_phase_end( """ batches = len(task.losses) if batches: - self._log_performance_metrics(task, local_variables) + self._log_performance_metrics(task) - def _log_performance_metrics( - self, task: "tasks.ClassyTask", local_variables: Dict[str, Any] - ) -> None: + def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None: """ Compute and log performance metrics. """ @@ -85,11 +81,11 @@ def _log_performance_metrics( ) # Train step time breakdown - if local_variables.get("perf_stats") is None: - logging.warning('"perf_stats" not set in local_variables') + if not hasattr(task, "perf_stats") or task.perf_stats is None: + logging.warning('"perf_stats" not set in task') elif task.train: logging.info( "Train step time breakdown (rank {}):\n{}".format( - get_rank(), local_variables["perf_stats"].report_str() + get_rank(), task.perf_stats.report_str() ) ) diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index c54c3ca7d5..3ae253be24 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -8,7 +8,7 @@ import enum import logging import time -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union import torch from classy_vision.dataset import ClassyDataset, build_dataset @@ -54,6 +54,13 @@ class BroadcastBuffersMode(enum.Enum): BEFORE_EVAL = enum.auto() +class LastBatchInfo(NamedTuple): + loss: torch.Tensor + output: torch.Tensor + target: torch.Tensor + sample: Dict[str, Any] + + @register_task("classification_task") class ClassificationTask(ClassyTask): """Basic classification training task. @@ -672,6 +679,14 @@ def eval_step(self, use_gpu, local_variables=None): self.update_meters(local_variables["output"], local_variables["sample"]) + # Move some data to the task so hooks get a chance to access it + self.last_batch = LastBatchInfo( + loss=local_variables["loss"], + output=local_variables["output"], + target=local_variables["target"], + sample=local_variables["sample"], + ) + def train_step(self, use_gpu, local_variables=None): """Train step to be executed in train loop @@ -684,6 +699,8 @@ def train_step(self, use_gpu, local_variables=None): if local_variables is None: local_variables = {} + self.last_batch = None + # Process next sample sample = next(self.get_data_iterator()) local_variables["sample"] = sample @@ -738,6 +755,14 @@ def train_step(self, use_gpu, local_variables=None): self.num_updates += self.get_global_batchsize() + # Move some data to the task so hooks get a chance to access it + self.last_batch = LastBatchInfo( + loss=local_variables["loss"], + output=local_variables["output"], + target=local_variables["target"], + sample=local_variables["sample"], + ) + def compute_loss(self, model_output, sample): return self.loss(model_output, sample["target"]) diff --git a/classy_vision/tasks/classy_task.py b/classy_vision/tasks/classy_task.py index 7d24e68ddb..ee3b6652a8 100644 --- a/classy_vision/tasks/classy_task.py +++ b/classy_vision/tasks/classy_task.py @@ -178,7 +178,8 @@ def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None: else: self.eval_step(use_gpu, local_variables) - self.run_hooks(local_variables, ClassyHookFunctions.on_step.name) + for hook in self.hooks: + hook.on_step(self) def run_hooks(self, local_variables: Dict[str, Any], hook_function: str) -> None: """ diff --git a/test/hooks_exponential_moving_average_model_hook_test.py b/test/hooks_exponential_moving_average_model_hook_test.py index 710403065a..4b1cab64b4 100644 --- a/test/hooks_exponential_moving_average_model_hook_test.py +++ b/test/hooks_exponential_moving_average_model_hook_test.py @@ -53,7 +53,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device): task.base_model.update_fc_weight() fc_weight = model.fc.weight.clone() for _ in range(num_updates): - exponential_moving_average_hook.on_step(task, local_variables) + exponential_moving_average_hook.on_step(task) exponential_moving_average_hook.on_phase_end(task, local_variables) # the model weights shouldn't have changed self.assertTrue(torch.allclose(model.fc.weight, fc_weight)) diff --git a/test/hooks_loss_lr_meter_logging_hook_test.py b/test/hooks_loss_lr_meter_logging_hook_test.py index d7499fc75d..2a7b3e06f4 100644 --- a/test/hooks_loss_lr_meter_logging_hook_test.py +++ b/test/hooks_loss_lr_meter_logging_hook_test.py @@ -51,27 +51,27 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None: for i in range(num_batches): task.losses = list(range(i)) - loss_lr_meter_hook.on_step(task, local_variables) + loss_lr_meter_hook.on_step(task) if log_freq is not None and i and i % log_freq == 0: - mock_fn.assert_called_with(task, local_variables) + mock_fn.assert_called_with(task) mock_fn.reset_mock() - mock_lr_fn.assert_called_with(task, local_variables) + mock_lr_fn.assert_called_with(task) mock_lr_fn.reset_mock() continue mock_fn.assert_not_called() mock_lr_fn.assert_not_called() loss_lr_meter_hook.on_phase_end(task, local_variables) - mock_fn.assert_called_with(task, local_variables) + mock_fn.assert_called_with(task) if task.train: - mock_lr_fn.assert_called_with(task, local_variables) + mock_lr_fn.assert_called_with(task) # test _log_loss_lr_meters() task.losses = losses with self.assertLogs(): - loss_lr_meter_hook._log_loss_meters(task, local_variables) - loss_lr_meter_hook._log_lr(task, local_variables) + loss_lr_meter_hook._log_loss_meters(task) + loss_lr_meter_hook._log_lr(task) task.phase_idx += 1 @@ -95,7 +95,7 @@ def scheduler_mock(where): lr_order = [0.0, 1 / 6, 1 / 6, 2 / 6, 3 / 6, 3 / 6, 4 / 6, 5 / 6, 5 / 6] lr_list = [] - def mock_log_lr(task: ClassyTask, local_variables) -> None: + def mock_log_lr(task: ClassyTask) -> None: lr_list.append(task.optimizer.parameters.lr) with mock.patch.object( diff --git a/test/hooks_time_metrics_hook_test.py b/test/hooks_time_metrics_hook_test.py index 296c42e561..ab9c0358bd 100644 --- a/test/hooks_time_metrics_hook_test.py +++ b/test/hooks_time_metrics_hook_test.py @@ -49,7 +49,7 @@ def test_time_metrics( mock_time.return_value = start_time time_metrics_hook.on_phase_start(task, local_variables) self.assertEqual(time_metrics_hook.start_time, start_time) - self.assertTrue(isinstance(local_variables.get("perf_stats"), PerfStats)) + self.assertTrue(isinstance(task.perf_stats, PerfStats)) # test that the code doesn't raise an exception if losses is empty try: @@ -66,15 +66,15 @@ def test_time_metrics( for i in range(num_batches): task.losses = list(range(i)) - time_metrics_hook.on_step(task, local_variables) + time_metrics_hook.on_step(task) if log_freq is not None and i and i % log_freq == 0: - mock_fn.assert_called_with(task, local_variables) + mock_fn.assert_called_with(task) mock_fn.reset_mock() continue mock_fn.assert_not_called() time_metrics_hook.on_phase_end(task, local_variables) - mock_fn.assert_called_with(task, local_variables) + mock_fn.assert_called_with(task) task.losses = [0.23, 0.45, 0.34, 0.67] @@ -84,7 +84,7 @@ def test_time_metrics( # test _log_performance_metrics() with self.assertLogs() as log_watcher: - time_metrics_hook._log_performance_metrics(task, local_variables) + time_metrics_hook._log_performance_metrics(task) # there should 2 be info logs for train and 1 for test self.assertEqual(len(log_watcher.output), 2 if train else 1) @@ -112,7 +112,7 @@ def test_time_metrics( # if on_phase_start() is not called, 2 warnings should be logged # create a new time metrics hook - local_variables = {} + task.perf_stats = None time_metrics_hook_new = TimeMetricsHook() with self.assertLogs() as log_watcher: diff --git a/test/manual/hooks_progress_bar_hook_test.py b/test/manual/hooks_progress_bar_hook_test.py index 3b53fca7c4..bfbd3a9747 100644 --- a/test/manual/hooks_progress_bar_hook_test.py +++ b/test/manual/hooks_progress_bar_hook_test.py @@ -48,14 +48,14 @@ def test_progress_bar( # on_step should update the progress bar correctly for i in range(num_batches): - progress_bar_hook.on_step(task, local_variables) + progress_bar_hook.on_step(task) mock_progress_bar.update.assert_called_once_with(i + 1) mock_progress_bar.update.reset_mock() # check that even if on_step is called again, the progress bar is # only updated with num_batches for _ in range(num_batches): - progress_bar_hook.on_step(task, local_variables) + progress_bar_hook.on_step(task) mock_progress_bar.update.assert_called_once_with(num_batches) mock_progress_bar.update.reset_mock() @@ -68,7 +68,7 @@ def test_progress_bar( # crash progress_bar_hook = ProgressBarHook() try: - progress_bar_hook.on_step(task, local_variables) + progress_bar_hook.on_step(task) progress_bar_hook.on_phase_end(task, local_variables) except Exception as e: self.fail( @@ -81,7 +81,7 @@ def test_progress_bar( progress_bar_hook = ProgressBarHook() try: progress_bar_hook.on_phase_start(task, local_variables) - progress_bar_hook.on_step(task, local_variables) + progress_bar_hook.on_step(task) progress_bar_hook.on_phase_end(task, local_variables) except Exception as e: self.fail("Received Exception when is_master() is False: {}".format(e)) diff --git a/test/manual/hooks_tensorboard_plot_hook_test.py b/test/manual/hooks_tensorboard_plot_hook_test.py index be125f0637..033fa46d70 100644 --- a/test/manual/hooks_tensorboard_plot_hook_test.py +++ b/test/manual/hooks_tensorboard_plot_hook_test.py @@ -62,7 +62,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None: # the writer if on_phase_start() is not called for initialization # before on_step() is called. with self.assertLogs() as log_watcher: - tensorboard_plot_hook.on_step(task, local_variables) + tensorboard_plot_hook.on_step(task) self.assertTrue( len(log_watcher.records) == 1 @@ -88,7 +88,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None: for loss in losses: task.losses.append(loss) - tensorboard_plot_hook.on_step(task, local_variables) + tensorboard_plot_hook.on_step(task) tensorboard_plot_hook.on_phase_end(task, local_variables) diff --git a/test/optim_param_scheduler_test.py b/test/optim_param_scheduler_test.py index 8b9ad6c486..08aca43ff9 100644 --- a/test/optim_param_scheduler_test.py +++ b/test/optim_param_scheduler_test.py @@ -207,7 +207,7 @@ class TestHook(ClassyHook): on_phase_end = ClassyHook._noop on_end = ClassyHook._noop - def on_step(self, task: ClassyTask, local_variables) -> None: + def on_step(self, task: ClassyTask) -> None: if not task.train: return From d2cd6f92947e064e53f97c09811c3cb3db8beb9b Mon Sep 17 00:00:00 2001 From: Vinicius Reis Date: Mon, 2 Mar 2020 09:56:49 -0800 Subject: [PATCH 2/2] Remove local_variables from train_step/eval_step Summary: This is part of a series of diffs to eliminate local_variables (see D20171981). Now that we've removed local_variables from step, remove it from train_step, eval_step. Differential Revision: D20170006 fbshipit-source-id: 3d044c9e383601662ed19c05bc05b1ee23667709 --- classy_vision/tasks/classification_task.py | 90 +++++++--------------- classy_vision/tasks/classy_task.py | 14 ++-- classy_vision/trainer/classy_trainer.py | 2 +- classy_vision/trainer/elastic_trainer.py | 2 +- test/tasks_classification_task_test.py | 5 +- 5 files changed, 35 insertions(+), 78 deletions(-) diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 3ae253be24..b476a43e26 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -637,118 +637,83 @@ def set_classy_state(self, state): # Set up pytorch module in train vs eval mode, update optimizer. self._set_model_train_mode() - def eval_step(self, use_gpu, local_variables=None): - if local_variables is None: - local_variables = {} - + def eval_step(self, use_gpu): # Process next sample sample = next(self.get_data_iterator()) - local_variables["sample"] = sample - assert ( - isinstance(local_variables["sample"], dict) - and "input" in local_variables["sample"] - and "target" in local_variables["sample"] - ), ( + assert isinstance(sample, dict) and "input" in sample and "target" in sample, ( f"Returned sample [{sample}] is not a map with 'input' and" + "'target' keys" ) # Copy sample to GPU - local_variables["target"] = local_variables["sample"]["target"] + target = sample["target"] if use_gpu: - for key, value in local_variables["sample"].items(): - local_variables["sample"][key] = recursive_copy_to_gpu( - value, non_blocking=True - ) + for key, value in sample.items(): + sample[key] = recursive_copy_to_gpu(value, non_blocking=True) with torch.no_grad(): - local_variables["output"] = self.model(local_variables["sample"]["input"]) + output = self.model(sample["input"]) - local_variables["local_loss"] = self.compute_loss( - local_variables["output"], local_variables["sample"] - ) + local_loss = self.compute_loss(output, sample) - local_variables["loss"] = local_variables["local_loss"].detach().clone() - local_variables["loss"] = all_reduce_mean(local_variables["loss"]) + loss = local_loss.detach().clone() + loss = all_reduce_mean(loss) - self.losses.append( - local_variables["loss"].data.cpu().item() - * local_variables["target"].size(0) - ) + self.losses.append(loss.data.cpu().item() * target.size(0)) - self.update_meters(local_variables["output"], local_variables["sample"]) + self.update_meters(output, sample) # Move some data to the task so hooks get a chance to access it self.last_batch = LastBatchInfo( - loss=local_variables["loss"], - output=local_variables["output"], - target=local_variables["target"], - sample=local_variables["sample"], + loss=loss, output=output, target=target, sample=sample ) - def train_step(self, use_gpu, local_variables=None): + def train_step(self, use_gpu): """Train step to be executed in train loop Args: use_gpu: if true, execute training on GPU - local_variables: Dict containing intermediate values - in train_step for access by hooks """ - if local_variables is None: - local_variables = {} - self.last_batch = None # Process next sample sample = next(self.get_data_iterator()) - local_variables["sample"] = sample - assert ( - isinstance(local_variables["sample"], dict) - and "input" in local_variables["sample"] - and "target" in local_variables["sample"] - ), ( + assert isinstance(sample, dict) and "input" in sample and "target" in sample, ( f"Returned sample [{sample}] is not a map with 'input' and" + "'target' keys" ) # Copy sample to GPU - local_variables["target"] = local_variables["sample"]["target"] + target = sample["target"] if use_gpu: - for key, value in local_variables["sample"].items(): - local_variables["sample"][key] = recursive_copy_to_gpu( - value, non_blocking=True - ) + for key, value in sample.items(): + sample[key] = recursive_copy_to_gpu(value, non_blocking=True) with torch.enable_grad(): # Forward pass - local_variables["output"] = self.model(local_variables["sample"]["input"]) + output = self.model(sample["input"]) - local_variables["local_loss"] = self.compute_loss( - local_variables["output"], local_variables["sample"] - ) + local_loss = self.compute_loss(output, sample) - local_variables["loss"] = local_variables["local_loss"].detach().clone() - local_variables["loss"] = all_reduce_mean(local_variables["loss"]) + loss = local_loss.detach().clone() + loss = all_reduce_mean(loss) - self.losses.append( - local_variables["loss"].data.cpu().item() - * local_variables["target"].size(0) - ) + self.losses.append(loss.data.cpu().item() * target.size(0)) - self.update_meters(local_variables["output"], local_variables["sample"]) + self.update_meters(output, sample) # Run backwards pass / update optimizer if self.amp_opt_level is not None: self.optimizer.zero_grad() with apex.amp.scale_loss( - local_variables["local_loss"], self.optimizer.optimizer + local_loss, self.optimizer.optimizer ) as scaled_loss: scaled_loss.backward() else: - self.optimizer.backward(local_variables["local_loss"]) + self.optimizer.backward(local_loss) self.optimizer.update_schedule_on_step(self.where) self.optimizer.step() @@ -757,10 +722,7 @@ def train_step(self, use_gpu, local_variables=None): # Move some data to the task so hooks get a chance to access it self.last_batch = LastBatchInfo( - loss=local_variables["loss"], - output=local_variables["output"], - target=local_variables["target"], - sample=local_variables["sample"], + loss=loss, output=output, target=target, sample=sample ) def compute_loss(self, model_output, sample): diff --git a/classy_vision/tasks/classy_task.py b/classy_vision/tasks/classy_task.py index ee3b6652a8..e8478c3e5a 100644 --- a/classy_vision/tasks/classy_task.py +++ b/classy_vision/tasks/classy_task.py @@ -107,7 +107,7 @@ def prepare( pass @abstractmethod - def train_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None: + def train_step(self, use_gpu) -> None: """ Run a train step. @@ -115,8 +115,6 @@ def train_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None: Args: use_gpu: True if training on GPUs, False otherwise - local_variables: Local variables created in the function. Can be passed to - custom :class:`classy_vision.hooks.ClassyHook`. """ pass @@ -157,7 +155,7 @@ def on_end(self, local_variables): pass @abstractmethod - def eval_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None: + def eval_step(self, use_gpu) -> None: """ Run an evaluation step. @@ -165,18 +163,16 @@ def eval_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None: Args: use_gpu: True if training on GPUs, False otherwise - local_variables: Local variables created in the function. Can be passed to - custom :class:`classy_vision.hooks.ClassyHook`. """ pass - def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None: + def step(self, use_gpu) -> None: from classy_vision.hooks import ClassyHookFunctions if self.train: - self.train_step(use_gpu, local_variables) + self.train_step(use_gpu) else: - self.eval_step(use_gpu, local_variables) + self.eval_step(use_gpu) for hook in self.hooks: hook.on_step(self) diff --git a/classy_vision/trainer/classy_trainer.py b/classy_vision/trainer/classy_trainer.py index 28ec02bb9d..25c6f3357d 100644 --- a/classy_vision/trainer/classy_trainer.py +++ b/classy_vision/trainer/classy_trainer.py @@ -77,7 +77,7 @@ def train(self, task: ClassyTask): task.on_phase_start(local_variables) while True: try: - task.step(self.use_gpu, local_variables) + task.step(self.use_gpu) except StopIteration: break task.on_phase_end(local_variables) diff --git a/classy_vision/trainer/elastic_trainer.py b/classy_vision/trainer/elastic_trainer.py index 36f445a05b..c604beb40a 100644 --- a/classy_vision/trainer/elastic_trainer.py +++ b/classy_vision/trainer/elastic_trainer.py @@ -106,7 +106,7 @@ def _run_step(self, state, local_variables, use_gpu): state.advance_to_next_phase = True state.skip_current_phase = False # Reset flag else: - state.task.step(use_gpu, local_variables) + state.task.step(use_gpu) except StopIteration: state.advance_to_next_phase = True diff --git a/test/tasks_classification_task_test.py b/test/tasks_classification_task_test.py index d9ea87ed4d..44617e3e46 100644 --- a/test/tasks_classification_task_test.py +++ b/test/tasks_classification_task_test.py @@ -73,7 +73,6 @@ def test_checkpointing(self): task_2 = build_task(config).set_hooks([LossLrMeterLoggingHook()]) use_gpu = torch.cuda.is_available() - local_variables = {} # prepare the tasks for the right device task.prepare(use_gpu=use_gpu) @@ -96,8 +95,8 @@ def test_checkpointing(self): # test that the train step runs the same way on both states # and the loss remains the same - task.train_step(use_gpu, local_variables) - task_2.train_step(use_gpu, local_variables) + task.train_step(use_gpu) + task_2.train_step(use_gpu) self._compare_states(task.get_classy_state(), task_2.get_classy_state()) def test_final_train_checkpoint(self):