Skip to content

Commit

Permalink
[pmap no rank reduce cleanup]: Prepare for flipping the
jax_pmap_no_rank_reduction flag. This flag slows down
utils.get_from_first_device, which must perform rank reduction for each
step, which makes this test time out. This helper was only needed because
the TrainingState was not properly marked as replicated, so we can just
mark it as replicated instead and avoid the timeout.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 667607592
Change-Id: If0a1756d38f4e1741495e8ed9d1e7cc2bb3695d3
  • Loading branch information
pschuh authored and copybara-github committed Aug 26, 2024
1 parent d9c4319 commit 8c2a8c8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
5 changes: 2 additions & 3 deletions acme/agents/jax/bc/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from acme.jax import types as jax_types
from acme.jax import utils
from acme.testing import fakes
import chex
import haiku as hk
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -103,7 +102,7 @@ class BCTest(parameterized.TestCase):
('peerbc',)
)
def test_continuous_actions(self, loss_name):
with chex.fake_pmap_and_jit():
with jax.disable_jit():
num_sgd_steps_per_step = 1
num_steps = 5

Expand Down Expand Up @@ -145,7 +144,7 @@ def test_continuous_actions(self, loss_name):
('logp',),
('rcal',))
def test_discrete_actions(self, loss_name):
with chex.fake_pmap_and_jit():
with jax.disable_jit():

num_sgd_steps_per_step = 1
num_steps = 5
Expand Down
44 changes: 28 additions & 16 deletions acme/agents/jax/bc/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,32 @@ def sgd_step(
# Split the input batch to `num_sgd_steps_per_step` minibatches in order
# to achieve better performance on accelerators.
sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

random_key, init_key = jax.random.split(random_key)
policy_params = networks.policy_network.init(init_key)
optimizer_state = optimizer.init(policy_params)

# Create initial state.
state = TrainingState(
optimizer_state=optimizer_state,
policy_params=policy_params,
key=random_key,
steps=0,
self._sgd_step = jax.pmap(
sgd_step,
axis_name=_PMAP_AXIS_NAME,
in_axes=(None, 0),
out_axes=(None, 0),
)
self._state = utils.replicate_in_all_devices(state)

def init_fn(random_key):
random_key, init_key = jax.random.split(random_key)
policy_params = networks.policy_network.init(init_key)
optimizer_state = optimizer.init(policy_params)

# Create initial state.
state = TrainingState(
optimizer_state=optimizer_state,
policy_params=policy_params,
key=random_key,
steps=0,
)
return state

state = jax.pmap(init_fn, out_axes=None)(
utils.replicate_in_all_devices(random_key)
)
self._state = state
self._state_sharding = jax.tree.map(lambda x: x.sharding, state)

self._timestamp = None

Expand All @@ -188,13 +200,13 @@ def step(self):

def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
variables = {
'policy': utils.get_from_first_device(self._state.policy_params),
'policy': self._state.policy_params,
}
return [variables[name] for name in names]

def save(self) -> TrainingState:
# Serialize only the first replica of parameters and optimizer state.
return jax.tree.map(utils.get_from_first_device, self._state)
return self._state

def restore(self, state: TrainingState):
self._state = utils.replicate_in_all_devices(state)
self._state = jax.device_put(state, self._state_sharding)
3 changes: 1 addition & 2 deletions acme/agents/jax/mbop/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from acme.agents.jax.mbop import networks as mbop_networks
from acme.testing import fakes
from acme.utils import loggers
import chex
import jax
import optax
import rlds
Expand All @@ -34,7 +33,7 @@
class MBOPTest(absltest.TestCase):

def test_learner(self):
with chex.fake_pmap_and_jit():
with jax.disable_jit():
num_sgd_steps_per_step = 1
num_steps = 5
num_networks = 7
Expand Down
4 changes: 1 addition & 3 deletions acme/agents/jax/mpo/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from acme.jax import variable_utils
from acme.utils import counting
from acme.utils import loggers
import chex
import jax
import optax
import reverb
Expand Down Expand Up @@ -162,8 +161,7 @@ def make_learner(self,
'learner',
steps_key=counter.get_steps_key() if counter else 'learner_steps')

with chex.fake_pmap_and_jit(not self.config.jit_learner,
not self.config.jit_learner):
with jax.disable_jit(not self.config.jit_learner):
learner = learning.MPOLearner(
iterator=dataset,
networks=networks,
Expand Down

0 comments on commit 8c2a8c8

Please sign in to comment.