Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Remove local_variables from train_step/eval_step #412

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions classy_vision/hooks/classy_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, a, b):
def __init__(self):
self.state = ClassyHookState()

def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
def _noop(self, *args, **kwargs) -> None:
"""Derived classes can set their hook functions to this.

This is useful if they want those hook functions to not do anything.
Expand Down Expand Up @@ -79,9 +79,7 @@ def on_phase_start(
pass

@abstractmethod
def on_step(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_step(self, task: "tasks.ClassyTask") -> None:
"""Called each time after parameters have been updated by the optimizer."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> Non
# state in the test phase
self._save_current_model_state(task.base_model, self.state.model_state)

def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
def on_step(self, task: ClassyTask) -> None:
if not task.train:
return

Expand Down
20 changes: 7 additions & 13 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,30 @@ def on_phase_end(
# trainer to implement an unsynced end of phase meter or
# for meters to not provide a sync function.
logging.info("End of phase metric values:")
self._log_loss_meters(task, local_variables)
self._log_loss_meters(task)
if task.train:
self._log_lr(task, local_variables)
self._log_lr(task)

def on_step(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_step(self, task: "tasks.ClassyTask") -> None:
"""
Log the LR every log_freq batches, if log_freq is not None.
"""
if self.log_freq is None or not task.train:
return
batches = len(task.losses)
if batches and batches % self.log_freq == 0:
self._log_lr(task, local_variables)
self._log_lr(task)
logging.info("Local unsynced metric values:")
self._log_loss_meters(task, local_variables)
self._log_loss_meters(task)

def _log_lr(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def _log_lr(self, task: "tasks.ClassyTask") -> None:
"""
Compute and log the optimizer LR.
"""
optimizer_lr = task.optimizer.parameters.lr
logging.info("Learning Rate: {}\n".format(optimizer_lr))

def _log_loss_meters(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
"""
Compute and log the loss and meters.
"""
Expand Down
4 changes: 1 addition & 3 deletions classy_vision/hooks/progress_bar_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def on_phase_start(
self.progress_bar = progressbar.ProgressBar(self.bar_size)
self.progress_bar.start()

def on_step(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_step(self, task: "tasks.ClassyTask") -> None:
"""Update the progress bar with the batch size."""
if task.train and is_master() and self.progress_bar is not None:
self.batches += 1
Expand Down
4 changes: 1 addition & 3 deletions classy_vision/hooks/tensorboard_plot_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def on_phase_start(
self.wall_times = []
self.num_steps_global = []

def on_step(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_step(self, task: "tasks.ClassyTask") -> None:
"""Store the observed learning rates."""
if self.learning_rates is None:
logging.warning("learning_rates is not initialized")
Expand Down
20 changes: 8 additions & 12 deletions classy_vision/hooks/time_metrics_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,17 @@ def on_phase_start(
Initialize start time and reset perf stats
"""
self.start_time = time.time()
local_variables["perf_stats"] = PerfStats()
task.perf_stats = PerfStats()

def on_step(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def on_step(self, task: "tasks.ClassyTask") -> None:
"""
Log metrics every log_freq batches, if log_freq is not None.
"""
if self.log_freq is None:
return
batches = len(task.losses)
if batches and batches % self.log_freq == 0:
self._log_performance_metrics(task, local_variables)
self._log_performance_metrics(task)

def on_phase_end(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
Expand All @@ -62,11 +60,9 @@ def on_phase_end(
"""
batches = len(task.losses)
if batches:
self._log_performance_metrics(task, local_variables)
self._log_performance_metrics(task)

def _log_performance_metrics(
self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
) -> None:
def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None:
"""
Compute and log performance metrics.
"""
Expand All @@ -85,11 +81,11 @@ def _log_performance_metrics(
)

# Train step time breakdown
if local_variables.get("perf_stats") is None:
logging.warning('"perf_stats" not set in local_variables')
if not hasattr(task, "perf_stats") or task.perf_stats is None:
logging.warning('"perf_stats" not set in task')
elif task.train:
logging.info(
"Train step time breakdown (rank {}):\n{}".format(
get_rank(), local_variables["perf_stats"].report_str()
get_rank(), task.perf_stats.report_str()
)
)
99 changes: 43 additions & 56 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import enum
import logging
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, NamedTuple, Optional, Union

import torch
from classy_vision.dataset import ClassyDataset, build_dataset
Expand Down Expand Up @@ -54,6 +54,13 @@ class BroadcastBuffersMode(enum.Enum):
BEFORE_EVAL = enum.auto()


class LastBatchInfo(NamedTuple):
loss: torch.Tensor
output: torch.Tensor
target: torch.Tensor
sample: Dict[str, Any]


@register_task("classification_task")
class ClassificationTask(ClassyTask):
"""Basic classification training task.
Expand Down Expand Up @@ -630,114 +637,94 @@ def set_classy_state(self, state):
# Set up pytorch module in train vs eval mode, update optimizer.
self._set_model_train_mode()

def eval_step(self, use_gpu, local_variables=None):
if local_variables is None:
local_variables = {}

def eval_step(self, use_gpu):
# Process next sample
sample = next(self.get_data_iterator())
local_variables["sample"] = sample

assert (
isinstance(local_variables["sample"], dict)
and "input" in local_variables["sample"]
and "target" in local_variables["sample"]
), (
assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
f"Returned sample [{sample}] is not a map with 'input' and"
+ "'target' keys"
)

# Copy sample to GPU
local_variables["target"] = local_variables["sample"]["target"]
target = sample["target"]
if use_gpu:
for key, value in local_variables["sample"].items():
local_variables["sample"][key] = recursive_copy_to_gpu(
value, non_blocking=True
)
for key, value in sample.items():
sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

with torch.no_grad():
local_variables["output"] = self.model(local_variables["sample"]["input"])
output = self.model(sample["input"])

local_variables["local_loss"] = self.compute_loss(
local_variables["output"], local_variables["sample"]
)
local_loss = self.compute_loss(output, sample)

local_variables["loss"] = local_variables["local_loss"].detach().clone()
local_variables["loss"] = all_reduce_mean(local_variables["loss"])
loss = local_loss.detach().clone()
loss = all_reduce_mean(loss)

self.losses.append(
local_variables["loss"].data.cpu().item()
* local_variables["target"].size(0)
)
self.losses.append(loss.data.cpu().item() * target.size(0))

self.update_meters(local_variables["output"], local_variables["sample"])
self.update_meters(output, sample)

def train_step(self, use_gpu, local_variables=None):
# Move some data to the task so hooks get a chance to access it
self.last_batch = LastBatchInfo(
loss=loss, output=output, target=target, sample=sample
)

def train_step(self, use_gpu):
"""Train step to be executed in train loop

Args:
use_gpu: if true, execute training on GPU
local_variables: Dict containing intermediate values
in train_step for access by hooks
"""

if local_variables is None:
local_variables = {}
self.last_batch = None

# Process next sample
sample = next(self.get_data_iterator())
local_variables["sample"] = sample

assert (
isinstance(local_variables["sample"], dict)
and "input" in local_variables["sample"]
and "target" in local_variables["sample"]
), (
assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
f"Returned sample [{sample}] is not a map with 'input' and"
+ "'target' keys"
)

# Copy sample to GPU
local_variables["target"] = local_variables["sample"]["target"]
target = sample["target"]
if use_gpu:
for key, value in local_variables["sample"].items():
local_variables["sample"][key] = recursive_copy_to_gpu(
value, non_blocking=True
)
for key, value in sample.items():
sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

with torch.enable_grad():
# Forward pass
local_variables["output"] = self.model(local_variables["sample"]["input"])
output = self.model(sample["input"])

local_variables["local_loss"] = self.compute_loss(
local_variables["output"], local_variables["sample"]
)
local_loss = self.compute_loss(output, sample)

local_variables["loss"] = local_variables["local_loss"].detach().clone()
local_variables["loss"] = all_reduce_mean(local_variables["loss"])
loss = local_loss.detach().clone()
loss = all_reduce_mean(loss)

self.losses.append(
local_variables["loss"].data.cpu().item()
* local_variables["target"].size(0)
)
self.losses.append(loss.data.cpu().item() * target.size(0))

self.update_meters(local_variables["output"], local_variables["sample"])
self.update_meters(output, sample)

# Run backwards pass / update optimizer
if self.amp_opt_level is not None:
self.optimizer.zero_grad()
with apex.amp.scale_loss(
local_variables["local_loss"], self.optimizer.optimizer
local_loss, self.optimizer.optimizer
) as scaled_loss:
scaled_loss.backward()
else:
self.optimizer.backward(local_variables["local_loss"])
self.optimizer.backward(local_loss)

self.optimizer.update_schedule_on_step(self.where)
self.optimizer.step()

self.num_updates += self.get_global_batchsize()

# Move some data to the task so hooks get a chance to access it
self.last_batch = LastBatchInfo(
loss=loss, output=output, target=target, sample=sample
)

def compute_loss(self, model_output, sample):
return self.loss(model_output, sample["target"])

Expand Down
17 changes: 7 additions & 10 deletions classy_vision/tasks/classy_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,14 @@ def prepare(
pass

@abstractmethod
def train_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
def train_step(self, use_gpu) -> None:
"""
Run a train step.

This corresponds to training over one batch of data from the dataloaders.

Args:
use_gpu: True if training on GPUs, False otherwise
local_variables: Local variables created in the function. Can be passed to
custom :class:`classy_vision.hooks.ClassyHook`.
"""
pass

Expand Down Expand Up @@ -157,28 +155,27 @@ def on_end(self, local_variables):
pass

@abstractmethod
def eval_step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
def eval_step(self, use_gpu) -> None:
"""
Run an evaluation step.

This corresponds to evaluating the model over one batch of data.

Args:
use_gpu: True if training on GPUs, False otherwise
local_variables: Local variables created in the function. Can be passed to
custom :class:`classy_vision.hooks.ClassyHook`.
"""
pass

def step(self, use_gpu, local_variables: Optional[Dict] = None) -> None:
def step(self, use_gpu) -> None:
from classy_vision.hooks import ClassyHookFunctions

if self.train:
self.train_step(use_gpu, local_variables)
self.train_step(use_gpu)
else:
self.eval_step(use_gpu, local_variables)
self.eval_step(use_gpu)

self.run_hooks(local_variables, ClassyHookFunctions.on_step.name)
for hook in self.hooks:
hook.on_step(self)

def run_hooks(self, local_variables: Dict[str, Any], hook_function: str) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion classy_vision/trainer/classy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def train(self, task: ClassyTask):
task.on_phase_start(local_variables)
while True:
try:
task.step(self.use_gpu, local_variables)
task.step(self.use_gpu)
except StopIteration:
break
task.on_phase_end(local_variables)
Expand Down
2 changes: 1 addition & 1 deletion classy_vision/trainer/elastic_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _run_step(self, state, local_variables, use_gpu):
state.advance_to_next_phase = True
state.skip_current_phase = False # Reset flag
else:
state.task.step(use_gpu, local_variables)
state.task.step(use_gpu)
except StopIteration:
state.advance_to_next_phase = True

Expand Down
2 changes: 1 addition & 1 deletion test/hooks_exponential_moving_average_model_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
task.base_model.update_fc_weight()
fc_weight = model.fc.weight.clone()
for _ in range(num_updates):
exponential_moving_average_hook.on_step(task, local_variables)
exponential_moving_average_hook.on_step(task)
exponential_moving_average_hook.on_phase_end(task, local_variables)
# the model weights shouldn't have changed
self.assertTrue(torch.allclose(model.fc.weight, fc_weight))
Expand Down
Loading