This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Remove local_variables from on_start #413

Closed · wants to merge 3 commits
Changes from all commits
4 changes: 1 addition & 3 deletions classy_vision/hooks/checkpoint_hook.py
@@ -78,9 +78,7 @@ def _save_checkpoint(self, task, filename):
         if checkpoint_file:
             PathManager.copy(checkpoint_file, f"{self.checkpoint_folder}/{filename}")

-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         if not is_master() or getattr(task, "test_only", False):
             return
         if not PathManager.exists(self.checkpoint_folder):
10 changes: 3 additions & 7 deletions classy_vision/hooks/classy_hook.py
@@ -51,7 +51,7 @@ def __init__(self, a, b):
     def __init__(self):
         self.state = ClassyHookState()

-    def _noop(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
+    def _noop(self, *args, **kwargs) -> None:
         """Derived classes can set their hook functions to this.

         This is useful if they want those hook functions to not do anything.
@@ -65,9 +65,7 @@ def name(cls) -> str:
         return cls.__name__

     @abstractmethod
-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """Called at the start of training."""
         pass

@@ -79,9 +77,7 @@ def on_phase_start(
         pass

     @abstractmethod
-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Called each time after parameters have been updated by the optimizer."""
         pass

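
With the base class relaxed like this, hook subclasses simply drop the `local_variables` argument from `on_start`/`on_step` (the phase hooks keep it in this PR). A minimal sketch of a custom hook written against the new signatures; the hook class name and its print-based body are illustrative, not part of this PR:

```python
from classy_vision import tasks
from classy_vision.hooks import ClassyHook


class LossCountHook(ClassyHook):
    """Illustrative hook using the local_variables-free on_start/on_step."""

    # The relaxed _noop(*args, **kwargs) satisfies either the old or the new signature.
    on_phase_start = ClassyHook._noop
    on_phase_end = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_start(self, task: "tasks.ClassyTask") -> None:
        # Hooks now read everything they need directly off the task object.
        print(f"Task starting; train={task.train}")

    def on_step(self, task: "tasks.ClassyTask") -> None:
        if task.losses:
            print(f"{len(task.losses)} batches seen so far")
```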
4 changes: 2 additions & 2 deletions classy_vision/hooks/exponential_moving_average_model_hook.py
@@ -78,7 +78,7 @@ def _save_current_model_state(self, model: nn.Module, model_state: Dict[str, Any
         for name, param in self.get_model_state_iterator(model):
             model_state[name] = param.detach().clone().to(device=self.device)

-    def on_start(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_start(self, task: ClassyTask) -> None:
         if self.state.model_state:
             # loaded state from checkpoint, do not re-initialize, only move the state
             # to the right device
@@ -103,7 +103,7 @@ def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> Non
             # state in the test phase
             self._save_current_model_state(task.base_model, self.state.model_state)

-    def on_step(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+    def on_step(self, task: ClassyTask) -> None:
         if not task.train:
             return

20 changes: 7 additions & 13 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
@@ -45,36 +45,30 @@ def on_phase_end(
         # trainer to implement an unsynced end of phase meter or
         # for meters to not provide a sync function.
         logging.info("End of phase metric values:")
-        self._log_loss_meters(task, local_variables)
+        self._log_loss_meters(task)
         if task.train:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log the LR every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None or not task.train:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_lr(task, local_variables)
+            self._log_lr(task)
             logging.info("Local unsynced metric values:")
-            self._log_loss_meters(task, local_variables)
+            self._log_loss_meters(task)

-    def _log_lr(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_lr(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the optimizer LR.
         """
         optimizer_lr = task.optimizer.parameters.lr
         logging.info("Learning Rate: {}\n".format(optimizer_lr))

-    def _log_loss_meters(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_loss_meters(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log the loss and meters.
         """
4 changes: 1 addition & 3 deletions classy_vision/hooks/model_complexity_hook.py
@@ -28,9 +28,7 @@ class ModelComplexityHook(ClassyHook):
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """Measure number of parameters, FLOPs and activations."""
         self.num_flops = 0
         self.num_activations = 0
4 changes: 1 addition & 3 deletions classy_vision/hooks/model_tensorboard_hook.py
@@ -51,9 +51,7 @@ def __init__(self, tb_writer) -> None:

         self.tb_writer = tb_writer

-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """
         Plot the model on Tensorboard.
         """
4 changes: 1 addition & 3 deletions classy_vision/hooks/profiler_hook.py
@@ -25,9 +25,7 @@ class ProfilerHook(ClassyHook):
     on_phase_end = ClassyHook._noop
     on_end = ClassyHook._noop

-    def on_start(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_start(self, task: "tasks.ClassyTask") -> None:
         """Profile the forward pass."""
         logging.info("Profiling forward pass...")
         batchsize_per_replica = getattr(
4 changes: 1 addition & 3 deletions classy_vision/hooks/progress_bar_hook.py
@@ -51,9 +51,7 @@ def on_phase_start(
             self.progress_bar = progressbar.ProgressBar(self.bar_size)
             self.progress_bar.start()

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Update the progress bar with the batch size."""
         if task.train and is_master() and self.progress_bar is not None:
             self.batches += 1
4 changes: 1 addition & 3 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -64,9 +64,7 @@ def on_phase_start(
         self.wall_times = []
         self.num_steps_global = []

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """Store the observed learning rates."""
         if self.learning_rates is None:
             logging.warning("learning_rates is not initialized")
20 changes: 8 additions & 12 deletions classy_vision/hooks/time_metrics_hook.py
@@ -40,19 +40,17 @@ def on_phase_start(
         Initialize start time and reset perf stats
         """
         self.start_time = time.time()
-        local_variables["perf_stats"] = PerfStats()
+        task.perf_stats = PerfStats()

-    def on_step(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def on_step(self, task: "tasks.ClassyTask") -> None:
         """
         Log metrics every log_freq batches, if log_freq is not None.
         """
         if self.log_freq is None:
             return
         batches = len(task.losses)
         if batches and batches % self.log_freq == 0:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)

     def on_phase_end(
         self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
@@ -62,11 +60,9 @@ def on_phase_end(
         """
         batches = len(task.losses)
         if batches:
-            self._log_performance_metrics(task, local_variables)
+            self._log_performance_metrics(task)

-    def _log_performance_metrics(
-        self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
-    ) -> None:
+    def _log_performance_metrics(self, task: "tasks.ClassyTask") -> None:
         """
         Compute and log performance metrics.
         """
@@ -85,11 +81,11 @@ def _log_performance_metrics(
         )

         # Train step time breakdown
-        if local_variables.get("perf_stats") is None:
-            logging.warning('"perf_stats" not set in local_variables')
+        if not hasattr(task, "perf_stats") or task.perf_stats is None:
+            logging.warning('"perf_stats" not set in task')
         elif task.train:
             logging.info(
                 "Train step time breakdown (rank {}):\n{}".format(
-                    get_rank(), local_variables["perf_stats"].report_str()
+                    get_rank(), task.perf_stats.report_str()
                 )
             )
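
Since the perf stats now live on the task rather than in `local_variables`, other hooks and user code can read them straight off the task. A rough sketch, assuming `TimeMetricsHook.on_phase_start` has already attached `task.perf_stats`; the hook class name below is hypothetical:

```python
import logging

from classy_vision.hooks import ClassyHook


class PerfStatsReportHook(ClassyHook):
    """Illustrative hook that reports the shared PerfStats at the end of a phase."""

    on_start = ClassyHook._noop
    on_phase_start = ClassyHook._noop
    on_step = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_phase_end(self, task, local_variables) -> None:
        # on_phase_end still receives local_variables in this PR; the stats
        # themselves come from the task attribute set by TimeMetricsHook.
        perf_stats = getattr(task, "perf_stats", None)
        if perf_stats is None:
            logging.warning("perf_stats not set on task")
            return
        logging.info("Phase perf stats:\n%s", perf_stats.report_str())
```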
104 changes: 46 additions & 58 deletions classy_vision/tasks/classification_task.py
@@ -8,7 +8,7 @@
 import enum
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Union

 import torch
 from classy_vision.dataset import ClassyDataset, build_dataset
@@ -54,6 +54,13 @@ class BroadcastBuffersMode(enum.Enum):
     BEFORE_EVAL = enum.auto()


+class LastBatchInfo(NamedTuple):
+    loss: torch.Tensor
+    output: torch.Tensor
+    target: torch.Tensor
+    sample: Dict[str, Any]
+
+
 @register_task("classification_task")
 class ClassificationTask(ClassyTask):
     """Basic classification training task.
@@ -630,114 +637,94 @@ def set_classy_state(self, state):
         # Set up pytorch module in train vs eval mode, update optimizer.
         self._set_model_train_mode()

-    def eval_step(self, use_gpu, local_variables=None):
-        if local_variables is None:
-            local_variables = {}
-
+    def eval_step(self, use_gpu):
         # Process next sample
         sample = next(self.get_data_iterator())
-        local_variables["sample"] = sample

-        assert (
-            isinstance(local_variables["sample"], dict)
-            and "input" in local_variables["sample"]
-            and "target" in local_variables["sample"]
-        ), (
+        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
             f"Returned sample [{sample}] is not a map with 'input' and"
             + "'target' keys"
         )

         # Copy sample to GPU
-        local_variables["target"] = local_variables["sample"]["target"]
+        target = sample["target"]
         if use_gpu:
-            for key, value in local_variables["sample"].items():
-                local_variables["sample"][key] = recursive_copy_to_gpu(
-                    value, non_blocking=True
-                )
+            for key, value in sample.items():
+                sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

         with torch.no_grad():
-            local_variables["output"] = self.model(local_variables["sample"]["input"])
+            output = self.model(sample["input"])

-            local_variables["local_loss"] = self.compute_loss(
-                local_variables["output"], local_variables["sample"]
-            )
+            local_loss = self.compute_loss(output, sample)

-            local_variables["loss"] = local_variables["local_loss"].detach().clone()
-            local_variables["loss"] = all_reduce_mean(local_variables["loss"])
+            loss = local_loss.detach().clone()
+            loss = all_reduce_mean(loss)

-            self.losses.append(
-                local_variables["loss"].data.cpu().item()
-                * local_variables["target"].size(0)
-            )
+            self.losses.append(loss.data.cpu().item() * target.size(0))

-            self.update_meters(local_variables["output"], local_variables["sample"])
+            self.update_meters(output, sample)
+
+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=loss, output=output, target=target, sample=sample
+        )

-    def train_step(self, use_gpu, local_variables=None):
+    def train_step(self, use_gpu):
         """Train step to be executed in train loop

         Args:
             use_gpu: if true, execute training on GPU
-            local_variables: Dict containing intermediate values
-                in train_step for access by hooks
         """

-        if local_variables is None:
-            local_variables = {}
+        self.last_batch = None

         # Process next sample
         sample = next(self.get_data_iterator())
-        local_variables["sample"] = sample

-        assert (
-            isinstance(local_variables["sample"], dict)
-            and "input" in local_variables["sample"]
-            and "target" in local_variables["sample"]
-        ), (
+        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
             f"Returned sample [{sample}] is not a map with 'input' and"
             + "'target' keys"
         )

         # Copy sample to GPU
-        local_variables["target"] = local_variables["sample"]["target"]
+        target = sample["target"]
         if use_gpu:
-            for key, value in local_variables["sample"].items():
-                local_variables["sample"][key] = recursive_copy_to_gpu(
-                    value, non_blocking=True
-                )
+            for key, value in sample.items():
+                sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

         with torch.enable_grad():
             # Forward pass
-            local_variables["output"] = self.model(local_variables["sample"]["input"])
+            output = self.model(sample["input"])

-            local_variables["local_loss"] = self.compute_loss(
-                local_variables["output"], local_variables["sample"]
-            )
+            local_loss = self.compute_loss(output, sample)

-            local_variables["loss"] = local_variables["local_loss"].detach().clone()
-            local_variables["loss"] = all_reduce_mean(local_variables["loss"])
+            loss = local_loss.detach().clone()
+            loss = all_reduce_mean(loss)

-            self.losses.append(
-                local_variables["loss"].data.cpu().item()
-                * local_variables["target"].size(0)
-            )
+            self.losses.append(loss.data.cpu().item() * target.size(0))

-            self.update_meters(local_variables["output"], local_variables["sample"])
+            self.update_meters(output, sample)

         # Run backwards pass / update optimizer
         if self.amp_opt_level is not None:
             self.optimizer.zero_grad()
             with apex.amp.scale_loss(
-                local_variables["local_loss"], self.optimizer.optimizer
+                local_loss, self.optimizer.optimizer
             ) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.optimizer.backward(local_variables["local_loss"])
+            self.optimizer.backward(local_loss)

         self.optimizer.update_schedule_on_step(self.where)
         self.optimizer.step()

         self.num_updates += self.get_global_batchsize()

+        # Move some data to the task so hooks get a chance to access it
+        self.last_batch = LastBatchInfo(
+            loss=loss, output=output, target=target, sample=sample
+        )
+
     def compute_loss(self, model_output, sample):
         return self.loss(model_output, sample["target"])

@@ -864,8 +851,9 @@ def get_global_batchsize(self):
         """
         return self.dataloaders[self.phase_type].dataset.get_global_batchsize()

-    def on_start(self, local_variables):
-        self.run_hooks(local_variables, ClassyHookFunctions.on_start.name)
+    def on_start(self):
+        for hook in self.hooks:
+            hook.on_start(self)

     def on_phase_start(self, local_variables):
         self.phase_start_time_total = time.perf_counter()
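
With the intermediate values promoted to `task.last_batch`, a step hook can read the loss, output, target, and sample of the most recent batch directly from the task. A sketch, assuming the hook fires after `train_step`/`eval_step` have populated `last_batch`; the hook class name is hypothetical:

```python
from classy_vision import tasks
from classy_vision.hooks import ClassyHook


class LastBatchInspectionHook(ClassyHook):
    """Illustrative hook that inspects the most recent batch via task.last_batch."""

    on_start = ClassyHook._noop
    on_phase_start = ClassyHook._noop
    on_phase_end = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_step(self, task: "tasks.ClassyTask") -> None:
        batch = getattr(task, "last_batch", None)
        if batch is None:  # e.g. the step has not populated it yet
            return
        # LastBatchInfo fields: loss, output, target, sample
        print(
            f"loss={batch.loss.item():.4f} "
            f"output shape={tuple(batch.output.shape)} "
            f"target shape={tuple(batch.target.shape)}"
        )
```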