Skip to content

Commit

Permalink
Rolling back [#3104](google/flax#3104) because internal tests are bre…
Browse files Browse the repository at this point in the history
…aking.

PiperOrigin-RevId: 537092577
  • Loading branch information
chiamp authored and copybara-github committed Jun 1, 2023
1 parent f7d73e3 commit fb4b751
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions vmoe/train/train_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from absl.testing import absltest
import chex
import flax
import numpy as np
import optax
from vmoe.train import train_state
Expand All @@ -25,26 +24,19 @@ class TrainStateTreeAxisResourcesTest(absltest.TestCase):
def test_apply_gradients_and_compute_global_norms(self):
state = train_state.TrainState.create(
apply_fn=lambda x: x,
params={'w': np.asarray((1.0,), dtype=np.float32)},
params={'w': np.asarray((1.,), dtype=np.float32)},
tx=optax.sgd(0.5),
rngs={},
)
grads = flax.core.freeze({'w': np.asarray((1.0,), dtype=np.float32)})
rngs={})
grads = {'w': np.asarray((1.,), dtype=np.float32)}
new_state, global_norms = state.apply_gradients_and_compute_global_norms(
grads, rngs={}
)
chex.assert_trees_all_close(
new_state.params,
flax.core.freeze({'w': np.asarray((0.5,), dtype=np.float32)}),
)
chex.assert_trees_all_close(
global_norms,
{
'grads': np.asarray((1.0,), dtype=np.float32),
'params': np.asarray((0.5,), dtype=np.float32),
'updates': np.asarray((0.5,), dtype=np.float32),
},
)
grads, rngs={})
chex.assert_trees_all_close(new_state.params,
{'w': np.asarray((.5,), dtype=np.float32)})
chex.assert_trees_all_close(global_norms, {
'grads': np.asarray((1.,), dtype=np.float32),
'params': np.asarray((.5,), dtype=np.float32),
'updates': np.asarray((.5,), dtype=np.float32),
})


if __name__ == '__main__':
Expand Down

0 comments on commit fb4b751

Please sign in to comment.