Multi-host/device training and Multi-device evals for the Loop api.
Current tests pass; more tests to follow.

TODO: Multi-host evals. Eval results still need to be collected from multiple hosts.

PiperOrigin-RevId: 323065836
afrozenator authored and copybara-github committed Jul 24, 2020
1 parent 17d4471 commit 524321e
Showing 5 changed files with 383 additions and 116 deletions.
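To make the change concrete, here is a rough usage sketch of the Loop API with the two arguments this commit threads through (`n_devices` and `random_seed`; see the `train()` hunk in `trainer_lib.py` below). The toy model, data stream, and task constructors are illustrative assumptions modeled loosely on the MNIST test in this repository, not part of the commit.

```python
import numpy as np

from trax import layers as tl
from trax import optimizers
from trax.supervised import training


def toy_batches(batch_size=32, dim=16, n_classes=4):
  """Endless stream of (inputs, targets, weights) batches for this sketch."""
  rng = np.random.RandomState(0)
  while True:
    x = rng.rand(batch_size, dim).astype(np.float32)
    y = rng.randint(n_classes, size=batch_size).astype(np.int32)
    yield (x, y, np.ones_like(y, dtype=np.float32))


model = tl.Serial(tl.Dense(32), tl.Relu(), tl.Dense(4), tl.LogSoftmax())
train_task = training.TrainTask(
    toy_batches(), tl.CrossEntropyLoss(), optimizers.Adam(0.01))
eval_task = training.EvalTask(
    toy_batches(), [tl.CrossEntropyLoss(), tl.Accuracy()], n_eval_batches=2)

# New in this commit: Loop accepts n_devices and random_seed directly, so the
# same code runs on one device or all local devices; multi-host evals are TODO.
loop = training.Loop(
    model, train_task,
    eval_tasks=[eval_task],
    output_dir='/tmp/toy_loop',
    n_devices=None,      # None: use every device the backend reports
    random_seed=17)      # None would derive a per-host seed
loop.run(200)
assert loop.step == 200
```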
8 changes: 8 additions & 0 deletions trax/optimizers/base.py
@@ -97,6 +97,14 @@ def slots(self):
def slots(self, slots):
self._slots = slots

+  @property
+  def opt_params(self):
+    return self._init_opt_params
+
+  @opt_params.setter
+  def opt_params(self, opt_params):
+    self._init_opt_params = opt_params

def tree_init(self, weight_tree):
"""Assembles node-local initializations into full-tree initialization.
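The new property exposes the optimizer's stored hyperparameters for reading and replacement, presumably so the multi-device Loop can read and restore them when setting up per-device optimizer state. A minimal sketch of the accessor pair, assuming Adam's usual constructor (the concrete values here are illustrative):

```python
from trax import optimizers

opt = optimizers.Adam(learning_rate=0.001)

# Getter returns the dict held in _init_opt_params (learning rate etc.);
# the setter writes a replacement dict back, e.g. one restored from a
# checkpoint or broadcast across hosts/devices.
params = opt.opt_params
params['learning_rate'] = 0.0005  # illustrative tweak
opt.opt_params = params
```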
2 changes: 1 addition & 1 deletion trax/supervised/mnist_test.py
@@ -58,7 +58,7 @@ def test_train_mnist(self):
eval_at=lambda step_n: step_n % 50 == 0)

training_session.run(n_steps=1000)
-    self.assertEqual(training_session.current_step, 1000)
+    self.assertEqual(training_session.step, 1000)


def _mnist_dataset():
62 changes: 7 additions & 55 deletions trax/supervised/trainer_lib.py
@@ -23,7 +23,6 @@
import functools
import itertools
import os
-import random
import sys
import time

@@ -32,7 +31,6 @@
import gin

import jax
-import numpy
import tensorflow.compat.v2 as tf
from trax import fastmath
from trax import jaxboard
@@ -85,8 +83,8 @@ def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
should_write_summaries=True,
metrics=None, checkpoint_highest=None, checkpoint_lowest=None):

-    self._is_chief, self._n_devices, rng = (
-        self._init_host_and_devices(n_devices, random_seed))
+    self._is_chief, _, self._n_devices, rng = (
+        training.init_host_and_devices(n_devices, random_seed))
self._should_save_checkpoints = should_save_checkpoints and self._is_chief
self._checkpoints_at = checkpoints_at or []
self._should_write_summaries = should_write_summaries
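The per-host/per-device setup this constructor used to do through the private `_init_host_and_devices` method (removed further down in this file) now lives in `trax.supervised.training` as a shared helper. A sketch of its assumed call signature, inferred from the unpacking above; the second return value is ignored by the Trainer and is presumably the host count (an assumption, since training.py is not shown in this diff):

```python
from trax.supervised import training

is_chief, n_hosts, n_devices, rng = training.init_host_and_devices(
    n_devices=None,     # None: use all devices reported by the backend
    random_seed=None)   # None: derive a per-host seed from time and host id
```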
@@ -339,7 +337,7 @@ def evaluate(self, n_eval_steps):
# TODO(lukaszkaiser): both model state and parameters by default include
# the loss layer. Currently, we access the pure-model parameters by just
# indexing, [0] here. But we should make it more explicit in a better API.
-    weights = (self._opt_state[0][0], self._metrics_weights)
+    weights = (self._opt_state.weights[0], self._metrics_weights)
state = (self._model_state[0], self._metrics_state)
self.log_step('Evaluation')
train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps)
@@ -434,7 +432,7 @@ def save_computation_graphs(self):
output_dir = self._output_dir
if self.n_devices > 1:
batch = _reshape_by_device(batch, self.n_devices)
-    weights = self._opt_state[0][0]
+    weights = self._opt_state.weights[0]
forward_computation = jax.xla_computation(self._model_predict_eval)(
batch, weights=weights, state=self._model_state[0],
rng=self._rngs[0])
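Both `evaluate` and `save_computation_graphs` switch from positional indexing into the optimizer state to attribute access. A tiny sketch of the assumed container shape that makes the two spellings equivalent (the real OptState lives elsewhere in trainer_lib.py and is not part of this diff, so its fields here are an assumption):

```python
import collections

# Assumed: the optimizer state is a namedtuple whose first field is `weights`,
# so opt_state.weights[0] reads the same object as opt_state[0][0].
OptState = collections.namedtuple('OptState', ['weights', 'slots', 'opt_params'])

opt_state = OptState(
    weights=[{'w': 1.0}], slots=[], opt_params={'learning_rate': 0.01})
assert opt_state.weights[0] is opt_state[0][0]
```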
@@ -470,42 +468,6 @@ def print_n_weights(self):
total_size = _nested_reduce(sum, sizes)
self.log_step('Total number of trainable weights: %d' % total_size)

-  def _init_host_and_devices(self, n_devices=None, random_seed=None):
-    """Initializes host and device attributes for this trainer.
-
-    Args:
-      n_devices: Number of devices this trainer will use. If `None`, get the
-        number from the backend.
-      random_seed: Random seed as the starting point for all random numbers
-        used by the trainer. If `None`, calculate one from system time and
-        host id.
-
-    Returns:
-      is_chief: True if this trainer has special chief responsibilities.
-      n_devices: The passed in value of n_devices or a computed default.
-      random_seed: The passed in value of random_seed or a computed default.
-    """
-    if fastmath.backend_name() == 'jax':
-      host_id = jax.host_id()
-      host_count = jax.host_count()
-    else:
-      host_id = 0
-      host_count = 1
-    is_chief = (host_id == 0)
-
-    logging.info('Initializing hosts and devices: host_id %d, host_count %d, '
-                 'is_chief %d', host_id, host_count, is_chief)
-
-    device_count = fastmath.device_count()
-    n_devices = n_devices or device_count
-    # TODO(lukaszkaiser): remove this restriction when possible.
-    if n_devices != device_count and fastmath.backend_name() == 'jax':
-      raise ValueError('JAX cannot work yet with n_devices != all devices: '
-                       '%d != %d' % (n_devices, device_count))
-
-    if random_seed is None and host_count > 1:
-      random_seed = int(1e6 * (host_id + time.time())) % 2**32
-    return is_chief, n_devices, init_random_number_generators(random_seed)
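For reference, the per-host seed derivation in the removed code mixes the host id into a time-based seed, so hosts that start at the same moment still draw different random streams. A small worked illustration of that formula:

```python
import time

now = time.time()
seed_host_0 = int(1e6 * (0 + now)) % 2**32
seed_host_1 = int(1e6 * (1 + now)) % 2**32
# The two seeds differ (by roughly 1e6, modulo 2**32), one per host.
assert seed_host_0 != seed_host_1
```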

def _should_save_now(self):
return self._should_save_checkpoints and self._step in self._checkpoints_at

@@ -587,7 +549,6 @@ def train(output_dir,
trax.TrainerState or training.Loop if use_loop is True
"""
if use_loop:
-    init_random_number_generators(random_seed)
n_devices = num_devices() or fastmath.device_count()

# Prepare the training task.
@@ -617,7 +578,9 @@
eval_model=model(mode='eval'),
eval_tasks=[eval_task],
output_dir=output_dir,
-      checkpoint_at=checkpoint_at)
+      checkpoint_at=checkpoint_at,
+      n_devices=n_devices,
+      random_seed=random_seed)

# Train and return the loop.
loop.run(steps)
@@ -865,17 +828,6 @@ def load_trainer_state(output_dir, model, weights_file=None):
return trainer_state


-def init_random_number_generators(seed=None):
-  """Initializes random generators for Python, NumPy, TensorFlow, and JAX."""
-  # Seed Python random (None as seed is okay), then use it to seed the others.
-  random.seed(seed)
-  if seed is None:
-    seed = random.randint(0, 2**31 - 1)
-  numpy.random.seed(seed)
-  tf.random.set_seed(seed)
-  return jax_random.get_prng(seed)


def _reshape_by_device(x, n_devices):
"""Reshapes possibly nested x into a shape (n_devices, ...)."""
return tl.reshape_by_device(x, n_devices) # pylint: disable=protected-access
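For orientation, this helper is what gives each device its own slice of a batch under data parallelism. An illustrative NumPy-only equivalent for a plain array (the real helper also handles nested structures; the shapes below are made up):

```python
import numpy as np

batch = np.arange(8 * 3).reshape(8, 3)          # 8 examples, 3 features
n_devices = 4
per_device = batch.reshape((n_devices, -1) + batch.shape[1:])
assert per_device.shape == (4, 2, 3)            # 2 examples per device
```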