
Remove local_variables from on_phase_end (#421)
Summary:
Pull Request resolved: #421

This is part of a series of diffs to eliminate local_variables (see D20171981).
Final step! Remove local_variables from on_phase_end.

Reviewed By: mannatsingh

Differential Revision: D20178293

fbshipit-source-id: 635001ba5331a1f43581c653c8f54932227c28f8
vreis authored and facebook-github-bot committed Mar 6, 2020
1 parent 0b2368e commit 636740b
Showing 20 changed files with 37 additions and 68 deletions.
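In practical terms the change is a signature cleanup: the local_variables dict is dropped from on_phase_end on both hooks and tasks, and the one consumer that still read something out of it (the elastic trainer's perf_stats lookup) switches to an attribute on the task instead. A minimal before/after sketch (signatures mirror the diffs below; the typing import is added only to keep the snippet self-contained):

    from typing import Any, Dict


    # Before this commit: the callback carried a dict it had no real use for.
    def on_phase_end(self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]) -> None:
        ...


    # After this commit: hooks read whatever they need directly from the task.
    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
        ...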
4 changes: 1 addition & 3 deletions classy_vision/hooks/checkpoint_hook.py
@@ -87,9 +87,7 @@ def on_start(self, task: "tasks.ClassyTask") -> None:
)
raise FileNotFoundError(err_msg)

- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""Checkpoint the task every checkpoint_period phases.
We do not necessarily checkpoint the task at the end of every phase.
4 changes: 1 addition & 3 deletions classy_vision/hooks/classy_hook.py
@@ -80,9 +80,7 @@ def on_step(self, task: "tasks.ClassyTask") -> None:
pass

@abstractmethod
- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""Called at the end of each phase (epoch)."""
pass

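For downstream code the visible impact is on custom hooks: subclasses of ClassyHook must drop the second parameter from on_phase_end, just as earlier diffs in this series did for the other callbacks. A hedged sketch of a hook written against the new interface follows; the hook name and the use of task.losses / task.phase_idx are illustrative only, and the base class may declare more abstract callbacks than the ones overridden here:

    import logging

    from classy_vision.hooks import ClassyHook


    class PhaseLossLoggerHook(ClassyHook):
        """Hypothetical hook: log the mean loss whenever a phase ends."""

        def on_start(self, task) -> None:
            pass

        def on_phase_start(self, task) -> None:
            pass

        def on_step(self, task) -> None:
            pass

        def on_end(self, task) -> None:
            pass

        def on_phase_end(self, task) -> None:
            # State that used to be threaded through local_variables is read
            # directly off the task.
            if task.losses:
                mean_loss = sum(task.losses) / len(task.losses)
                logging.info("phase %d: mean loss %.4f", task.phase_idx, mean_loss)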
classy_vision/hooks/exponential_moving_average_model_hook.py
@@ -97,7 +97,7 @@ def on_phase_start(self, task: ClassyTask) -> None:
# restore the right state depending on the phase type
self.set_model_state(task, use_ema=not task.train)

- def on_phase_end(self, task: ClassyTask, local_variables: Dict[str, Any]) -> None:
+ def on_phase_end(self, task: ClassyTask) -> None:
if task.train:
# save the current model state since this will be overwritten by the ema
# state in the test phase
4 changes: 1 addition & 3 deletions classy_vision/hooks/loss_lr_meter_logging_hook.py
@@ -32,9 +32,7 @@ def __init__(self, log_freq: Optional[int] = None) -> None:
super().__init__()
self.log_freq: Optional[int] = log_freq

- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""
Log the loss, optimizer LR, and meters for the phase.
"""
4 changes: 1 addition & 3 deletions classy_vision/hooks/progress_bar_hook.py
@@ -55,9 +55,7 @@ def on_step(self, task: "tasks.ClassyTask") -> None:
self.batches += 1
self.progress_bar.update(min(self.batches, self.bar_size))

- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""Clear the progress bar at the end of the phase."""
if is_master() and self.progress_bar is not None:
self.progress_bar.finish()
4 changes: 1 addition & 3 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -78,9 +78,7 @@ def on_step(self, task: "tasks.ClassyTask") -> None:
self.wall_times.append(time.time())
self.num_steps_global.append(task.num_updates)

- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""Add the losses and learning rates to tensorboard."""
if self.learning_rates is None:
logging.warning("learning_rates is not initialized")
4 changes: 1 addition & 3 deletions classy_vision/hooks/time_metrics_hook.py
@@ -50,9 +50,7 @@ def on_step(self, task: "tasks.ClassyTask") -> None:
if batches and batches % self.log_freq == 0:
self._log_performance_metrics(task)

- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""
Log metrics at the end of a phase if log_freq is None.
"""
4 changes: 1 addition & 3 deletions classy_vision/hooks/visdom_hook.py
@@ -60,9 +60,7 @@ def __init__(
self.metrics: Dict = {}
self.visdom: Visdom = Visdom(self.server, self.port)

- def on_phase_end(
-     self, task: "tasks.ClassyTask", local_variables: Dict[str, Any]
- ) -> None:
+ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""
Plot the metrics on visdom.
"""
5 changes: 3 additions & 2 deletions classy_vision/tasks/classification_task.py
@@ -868,7 +868,7 @@ def on_phase_start(self):

self.phase_start_time_train = time.perf_counter()

- def on_phase_end(self, local_variables):
+ def on_phase_end(self):
self.log_phase_end("train")

logging.info("Syncing meters on phase end...")
@@ -877,7 +877,8 @@ def on_phase_end(self, local_variables):
logging.info("...meters synced")
barrier()

- self.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
+ for hook in self.hooks:
+     hook.on_phase_end(self)
self.perf_log = []

self.log_phase_end("total")
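With run_hooks(local_variables, ...) gone, the task now calls its hooks directly, and anything a phase wants to hand to them has to live on the task object itself. A small self-contained sketch of that dispatch (the real code is the two-line loop in the hunk above; treating perf_stats and other per-phase state as task attributes is an assumption motivated by the elastic-trainer change below):

    def run_phase_end_hooks(task) -> None:
        """Sketch of the explicit dispatch that replaces name-based run_hooks()."""
        # Per-phase state that hooks may want (losses, meters, perf stats, ...)
        # is expected to already hang off the task at this point.
        for hook in task.hooks:
            hook.on_phase_end(task)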
2 changes: 1 addition & 1 deletion classy_vision/tasks/classy_task.py
@@ -137,7 +137,7 @@ def on_phase_start(self):
pass

@abstractmethod
- def on_phase_end(self, local_variables):
+ def on_phase_end(self):
"""
Epoch end.
4 changes: 1 addition & 3 deletions classy_vision/trainer/classy_trainer.py
@@ -70,8 +70,6 @@ def train(self, task: ClassyTask):
# this helps catch hangs which would have happened elsewhere
barrier()

- local_variables = {}
-
task.on_start()
while not task.done_training():
task.on_phase_start()
@@ -80,5 +78,5 @@ task.step(self.use_gpu)
task.step(self.use_gpu)
except StopIteration:
break
- task.on_phase_end(local_variables)
+ task.on_phase_end()
task.on_end()
10 changes: 4 additions & 6 deletions classy_vision/trainer/elastic_trainer.py
@@ -70,8 +70,6 @@ def train(self, task):
)
state = self._ClassyElasticState(task, self.input_args)

- local_variables = {}
-
state.advance_to_next_phase = True

def elastic_train_step(orig_state):
@@ -82,13 +80,13 @@ def elastic_train_step(orig_state):
state.run_start_hooks = False
return state, self._ClassyWorkerStats(None)

- return self._run_step(orig_state, local_variables, self.use_gpu)
+ return self._run_step(orig_state, self.use_gpu)

torchelastic.train(self.elastic_coordinator, elastic_train_step, state)

task.on_end()

- def _run_step(self, state, local_variables, use_gpu):
+ def _run_step(self, state, use_gpu):
# Check for training complete but only terminate when the last phase is done
if state.task.done_training() and state.advance_to_next_phase:
raise StopIteration
@@ -112,10 +110,10 @@ def _run_step(self, state, local_variables, use_gpu):

if state.advance_to_next_phase:
self.elastic_coordinator.barrier()
- state.task.on_phase_end(local_variables)
+ state.task.on_phase_end()

progress_rate = None # using None to signal 'unknown'
- perf_stats = local_variables.get("perf_stats", None)
+ perf_stats = getattr(state.task, "perf_stats", None)
if perf_stats is not None:
batch_time = perf_stats._cuda_stats["train_step_total"].smoothed_value
if batch_time is not None and batch_time > 0.0:
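Note the replacement for local_variables.get("perf_stats", None): the stats are now expected as an attribute on the task, and getattr with a None default keeps the behavior graceful for tasks that never record them. A hedged sketch of the pattern (the attribute name and the _cuda_stats access come from the diff above; the final conversion to a rate is a guess, since the hunk is cut off here):

    from typing import Optional


    def progress_rate_from_task(task) -> Optional[float]:
        """Sketch: derive an optional progress rate from stats attached to the task."""
        perf_stats = getattr(task, "perf_stats", None)  # None if the task never recorded stats
        if perf_stats is None:
            return None  # the coordinator treats None as "unknown"
        batch_time = perf_stats._cuda_stats["train_step_total"].smoothed_value
        if batch_time is None or batch_time <= 0.0:
            return None
        return 1.0 / batch_time  # hypothetical conversion; the exact formula is not shown in this hunk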
16 changes: 6 additions & 10 deletions test/hooks_checkpoint_hook_test.py
@@ -33,7 +33,6 @@ def test_state_checkpointing(self) -> None:
task = build_task(config)
task.prepare()

- local_variables = {}
checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
input_args = {"foo": "bar"}

@@ -48,7 +47,7 @@ def test_state_checkpointing(self) -> None:
checkpoint_hook.on_start(task)
# call the on end phase function
with self.assertRaises(AssertionError):
- checkpoint_hook.on_phase_end(task, local_variables)
+ checkpoint_hook.on_phase_end(task)
# try loading a non-existent checkpoint
checkpoint = load_checkpoint(checkpoint_folder)
self.assertIsNone(checkpoint)
@@ -60,13 +59,13 @@ def test_state_checkpointing(self) -> None:
# Phase_type is test, expect no checkpoint
task.train = False
# call the on end phase function
- checkpoint_hook.on_phase_end(task, local_variables)
+ checkpoint_hook.on_phase_end(task)
checkpoint = load_checkpoint(checkpoint_folder)
self.assertIsNone(checkpoint)

task.train = True
# call the on end phase function
- checkpoint_hook.on_phase_end(task, local_variables)
+ checkpoint_hook.on_phase_end(task)
# model should be checkpointed. load and compare
checkpoint = load_checkpoint(checkpoint_folder)
self.assertIsNotNone(checkpoint)
@@ -84,7 +83,6 @@ def test_checkpoint_period(self) -> None:
task = build_task(config)
task.prepare()

- local_variables = {}
checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
checkpoint_period = 10

@@ -110,7 +108,7 @@ def test_checkpoint_period(self) -> None:
while valid_phase_count < checkpoint_period - 1:
task.train = count % 2 == 0
# call the on end phase function
- checkpoint_hook.on_phase_end(task, local_variables)
+ checkpoint_hook.on_phase_end(task)
checkpoint = load_checkpoint(checkpoint_folder)
self.assertIsNone(checkpoint)
valid_phase_count += 1 if task.phase_type in phase_types else 0
@@ -119,7 +117,7 @@ def test_checkpoint_period(self) -> None:
# create a phase which is in phase_types
task.train = True
# call the on end phase function
- checkpoint_hook.on_phase_end(task, local_variables)
+ checkpoint_hook.on_phase_end(task)
# model should be checkpointed. load and compare
checkpoint = load_checkpoint(checkpoint_folder)
self.assertIsNotNone(checkpoint)
@@ -137,13 +135,11 @@ def test_checkpointing(self):

task.prepare(use_gpu=cuda_available)

- local_variables = {}
-
# create a checkpoint hook
checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"])

# call the on end phase function
- checkpoint_hook.on_phase_end(task, local_variables)
+ checkpoint_hook.on_phase_end(task)

# we should be able to train a task using the checkpoint on all available
# devices
5 changes: 2 additions & 3 deletions test/hooks_exponential_moving_average_model_hook_test.py
@@ -37,7 +37,6 @@ def _map_device_string(self, device):
def _test_exponential_moving_average_hook(self, model_device, hook_device):
task = mock.MagicMock()
model = TestModel().to(device=self._map_device_string(model_device))
- local_variables = {}
task.base_model = model
task.train = True
decay = 0.5
@@ -54,14 +53,14 @@ def _test_exponential_moving_average_hook(self, model_device, hook_device):
fc_weight = model.fc.weight.clone()
for _ in range(num_updates):
exponential_moving_average_hook.on_step(task)
- exponential_moving_average_hook.on_phase_end(task, local_variables)
+ exponential_moving_average_hook.on_phase_end(task)
# the model weights shouldn't have changed
self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

# simulate a test phase now
task.train = False
exponential_moving_average_hook.on_phase_start(task)
- exponential_moving_average_hook.on_phase_end(task, local_variables)
+ exponential_moving_average_hook.on_phase_end(task)

# the model weights should be updated to the ema weights
self.assertTrue(
3 changes: 1 addition & 2 deletions test/hooks_loss_lr_meter_logging_hook_test.py
@@ -34,7 +34,6 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:

losses = [1.2, 2.3, 3.4, 4.5]

- local_variables = {}
task.phase_idx = 0

for log_freq in [5, None]:
@@ -61,7 +60,7 @@ def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
mock_fn.assert_not_called()
mock_lr_fn.assert_not_called()

- loss_lr_meter_hook.on_phase_end(task, local_variables)
+ loss_lr_meter_hook.on_phase_end(task)
mock_fn.assert_called_with(task)
if task.train:
mock_lr_fn.assert_called_with(task)
7 changes: 3 additions & 4 deletions test/hooks_time_metrics_hook_test.py
@@ -32,7 +32,6 @@ def test_time_metrics(
mock_get_rank.return_value = rank

mock_report_str.return_value = ""
- local_variables = {}

for log_freq, train in product([5, None], [True, False]):
# create a time metrics hook
@@ -53,7 +52,7 @@

# test that the code doesn't raise an exception if losses is empty
try:
- time_metrics_hook.on_phase_end(task, local_variables)
+ time_metrics_hook.on_phase_end(task)
except Exception as e:
self.fail("Received Exception when losses is []: {}".format(e))

@@ -73,7 +72,7 @@
continue
mock_fn.assert_not_called()

- time_metrics_hook.on_phase_end(task, local_variables)
+ time_metrics_hook.on_phase_end(task)
mock_fn.assert_called_with(task)

task.losses = [0.23, 0.45, 0.34, 0.67]
@@ -116,7 +115,7 @@ def test_time_metrics(
time_metrics_hook_new = TimeMetricsHook()

with self.assertLogs() as log_watcher:
- time_metrics_hook_new.on_phase_end(task, local_variables)
+ time_metrics_hook_new.on_phase_end(task)

self.assertEqual(len(log_watcher.output), 2)
self.assertTrue(
8 changes: 3 additions & 5 deletions test/manual/hooks_progress_bar_hook_test.py
@@ -26,8 +26,6 @@ def test_progress_bar(

mock_is_master.return_value = True

- local_variables = {}
-
task = get_test_classy_task()
task.prepare()
task.advance_phase()
@@ -60,7 +58,7 @@ def test_progress_bar(
mock_progress_bar.update.reset_mock()

# finish should be called on the progress bar
- progress_bar_hook.on_phase_end(task, local_variables)
+ progress_bar_hook.on_phase_end(task)
mock_progress_bar.finish.assert_called_once_with()
mock_progress_bar.finish.reset_mock()

@@ -69,7 +67,7 @@
progress_bar_hook = ProgressBarHook()
try:
progress_bar_hook.on_step(task)
- progress_bar_hook.on_phase_end(task, local_variables)
+ progress_bar_hook.on_phase_end(task)
except Exception as e:
self.fail(
"Received Exception when on_phase_start() isn't called: {}".format(e)
@@ -82,7 +80,7 @@
try:
progress_bar_hook.on_phase_start(task)
progress_bar_hook.on_step(task)
- progress_bar_hook.on_phase_end(task, local_variables)
+ progress_bar_hook.on_phase_end(task)
except Exception as e:
self.fail("Received Exception when is_master() is False: {}".format(e))
self.assertIsNone(progress_bar_hook.progress_bar)
6 changes: 2 additions & 4 deletions test/manual/hooks_tensorboard_plot_hook_test.py
@@ -49,8 +49,6 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:

losses = [1.23, 4.45, 12.3, 3.4]

- local_variables = {}
-
summary_writer = SummaryWriter(self.base_dir)
# create a spy on top of summary_writer
summary_writer = mock.MagicMock(wraps=summary_writer)
@@ -74,7 +72,7 @@
# the writer if on_phase_start() is not called for initialization
# if on_phase_end() is called.
with self.assertLogs() as log_watcher:
- tensorboard_plot_hook.on_phase_end(task, local_variables)
+ tensorboard_plot_hook.on_phase_end(task)

self.assertTrue(
len(log_watcher.records) == 1
@@ -90,7 +88,7 @@ def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
task.losses.append(loss)
tensorboard_plot_hook.on_step(task)

- tensorboard_plot_hook.on_phase_end(task, local_variables)
+ tensorboard_plot_hook.on_phase_end(task)

if master:
# add_scalar() should have been called with the right scalars
