Skip to content

Commit

Permalink
Remove uses of JAX-internal test utilites from Flax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 481203353
  • Loading branch information
hawkinsp authored and Flax Authors committed Oct 14, 2022
1 parent 16a95f0 commit e0de630
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 36 deletions.
64 changes: 34 additions & 30 deletions tests/checkpoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from flax.training import checkpoints
import jax
from jax import numpy as jnp
from jax._src import test_util as jtu
import numpy as np
from tensorflow.io import gfile

Expand All @@ -38,6 +37,11 @@
PyTree = Any


def check_eq(xs, ys):
return jax.tree_util.tree_all(
jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys))


def shuffle(l):
"""Functional shuffle."""
l = copy.copy(l)
Expand Down Expand Up @@ -99,42 +103,42 @@ def test_save_restore_checkpoints(self):
'b': np.array([2, 2, 2], np.int32)}
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
# Create leftover temporary checkpoint, which should be ignored.
gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
checkpoints.save_checkpoint(
tmp_dir, test_object1, 0, prefix='test_', keep=1)
self.assertIn('test_0', os.listdir(tmp_dir))
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)
checkpoints.save_checkpoint(
tmp_dir, test_object1, 1, prefix='test_', keep=1)
checkpoints.save_checkpoint(
tmp_dir, test_object2, 2, prefix='test_', keep=1)
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object2)
check_eq(new_object, test_object2)
checkpoints.save_checkpoint(
tmp_dir, test_object2, 3, prefix='test_', keep=2)
checkpoints.save_checkpoint(
tmp_dir, test_object1, 4, prefix='test_', keep=2)
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, step=3, prefix='test_')
jtu.check_eq(new_object, test_object2)
check_eq(new_object, test_object2)
# Restore a specific path.
new_object = checkpoints.restore_checkpoint(
os.path.join(tmp_dir, 'test_3'), test_object0)
jtu.check_eq(new_object, test_object2)
check_eq(new_object, test_object2)
# If a specific path is specified, but it does not exist, the same behavior
# as when a directory is empty should apply: the target is returned
# unchanged.
new_object = checkpoints.restore_checkpoint(
os.path.join(tmp_dir, 'test_not_there'), test_object0)
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
with self.assertRaises(ValueError):
checkpoints.restore_checkpoint(
tmp_dir, test_object0, step=5, prefix='test_')
Expand All @@ -153,25 +157,25 @@ def test_overwrite_checkpoints(self):
checkpoints.save_checkpoint(tmp_dir, test_object, 0, keep=1, overwrite=True)

new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0)
jtu.check_eq(new_object, test_object)
check_eq(new_object, test_object)
checkpoints.save_checkpoint(
tmp_dir, test_object0, 2, keep=1, overwrite=True)
new_object = checkpoints.restore_checkpoint(tmp_dir, test_object)
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
with self.assertRaises(errors.InvalidCheckpointError):
checkpoints.save_checkpoint(tmp_dir, test_object, 1, keep=1)
checkpoints.save_checkpoint(tmp_dir, test_object, 1, keep=1, overwrite=True)
new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0)
jtu.check_eq(new_object, test_object)
check_eq(new_object, test_object)
os.chdir(os.path.dirname(tmp_dir))
rel_tmp_dir = './' + os.path.basename(tmp_dir)
checkpoints.save_checkpoint(rel_tmp_dir, test_object, 3, keep=1)
new_object = checkpoints.restore_checkpoint(rel_tmp_dir, test_object0)
jtu.check_eq(new_object, test_object)
check_eq(new_object, test_object)
non_norm_dir_path = tmp_dir + '//'
checkpoints.save_checkpoint(non_norm_dir_path, test_object, 4, keep=1)
new_object = checkpoints.restore_checkpoint(non_norm_dir_path, test_object0)
jtu.check_eq(new_object, test_object)
check_eq(new_object, test_object)

@parameterized.parameters({'keep_every_n_steps': None},
{'keep_every_n_steps': 7})
Expand All @@ -196,7 +200,7 @@ def test_keep(self, keep_every_n_steps):
step - last_checkpoint) >= keep_every_n_steps):
restored = checkpoints.restore_checkpoint(
tmp_dir, target=None, step=step)
jtu.check_eq(restored, test_object)
check_eq(restored, test_object)
last_checkpoint = step
else:
with self.assertRaises(ValueError):
Expand All @@ -217,7 +221,7 @@ def test_save_restore_checkpoints_w_float_steps(self):
self.assertIn('test_0.0', os.listdir(tmp_dir))
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)
checkpoints.save_checkpoint(
tmp_dir, test_object1, 2.0, prefix='test_', keep=1)
with self.assertRaises(errors.InvalidCheckpointError):
Expand All @@ -227,7 +231,7 @@ def test_save_restore_checkpoints_w_float_steps(self):
tmp_dir, test_object2, 3.0, prefix='test_', keep=2)
self.assertIn('test_3.0', os.listdir(tmp_dir))
self.assertIn('test_2.0', os.listdir(tmp_dir))
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)

def test_save_restore_checkpoints_target_none(self):
tmp_dir = self.create_tempdir().full_path
Expand All @@ -236,36 +240,36 @@ def test_save_restore_checkpoints_target_none(self):
# Target pytree is a dictionary, so it's equal to a restored state_dict.
checkpoints.save_checkpoint(tmp_dir, test_object0, 0)
new_object = checkpoints.restore_checkpoint(tmp_dir, target=None)
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
# Target pytree it's a tuple, check the expected state_dict is recovered.
test_object1 = (np.array([0, 0, 0], np.int32),
np.array([1, 1, 1], np.int32))
checkpoints.save_checkpoint(tmp_dir, test_object1, 1)
new_object = checkpoints.restore_checkpoint(tmp_dir, target=None)
expected_new_object = {str(k): v for k, v in enumerate(test_object1)}
jtu.check_eq(new_object, expected_new_object)
check_eq(new_object, expected_new_object)

def test_save_restore_checkpoints_target_singular(self):
tmp_dir = self.create_tempdir().full_path
test_object0 = np.array([0, 0, 0], np.int32)
test_object1 = np.array([1, 1, 1], np.int32)
checkpoints.save_checkpoint(tmp_dir, test_object1, 0)
new_object = checkpoints.restore_checkpoint(tmp_dir, target=None)
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)
checkpoints.save_checkpoint(tmp_dir, test_object0, 1)
new_object = checkpoints.restore_checkpoint(tmp_dir, target=test_object1)
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)

def test_save_restore_checkpoints_target_empty(self):
tmp_dir = self.create_tempdir().full_path
test_object0 = {}
test_object1 = []
checkpoints.save_checkpoint(tmp_dir, test_object1, 0)
new_object = checkpoints.restore_checkpoint(tmp_dir, target=None)
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
checkpoints.save_checkpoint(tmp_dir, test_object0, 1)
new_object = checkpoints.restore_checkpoint(tmp_dir, target=test_object1)
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)

def test_async_save_checkpoints(self):
tmp_dir = pathlib.Path(self.create_tempdir().full_path)
Expand All @@ -279,7 +283,7 @@ def test_async_save_checkpoints(self):
'b': np.random.normal(size=(1000, 1000))}
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
# Create leftover temporary checkpoint, which should be ignored.
gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
am = checkpoints.AsyncManager()
Expand All @@ -290,7 +294,7 @@ def test_async_save_checkpoints(self):
self.assertIn('test_0', os.listdir(tmp_dir))
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object1, prefix='test_')
jtu.check_eq(new_object, test_object1)
check_eq(new_object, test_object1)
# Check two consecutive saves happen in the right order.
checkpoints.save_checkpoint(
tmp_dir, test_object2, 1, prefix='test_', keep=1, async_manager=am)
Expand All @@ -300,24 +304,24 @@ def test_async_save_checkpoints(self):
self.assertIn('test_2', os.listdir(tmp_dir))
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object1, prefix='test_')
jtu.check_eq(new_object, test_object3)
check_eq(new_object, test_object3)

def test_last_checkpoint(self):
tmp_dir = pathlib.Path(self.create_tempdir().full_path)
with gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as f:
f.write('test_tmp')
f.write('test_tmp')
gfile.makedirs(os.path.join(tmp_dir, 'test_tmp_gda'))
self.assertEqual(checkpoints.latest_checkpoint(tmp_dir, 'test_'),
None)

with gfile.GFile(os.path.join(tmp_dir, 'test_0'), 'w') as f:
f.write('test_0')
f.write('test_0')
gfile.makedirs(os.path.join(tmp_dir, 'test_0_gda'))
self.assertEqual(checkpoints.latest_checkpoint(tmp_dir, 'test_'),
os.path.join(tmp_dir, 'test_0'))

with gfile.GFile(os.path.join(tmp_dir, 'test_10'), 'w') as f:
f.write('test_10')
f.write('test_10')
self.assertEqual(checkpoints.latest_checkpoint(tmp_dir, 'test_'),
os.path.join(tmp_dir, 'test_10'))
self.assertEqual(checkpoints.latest_checkpoint(tmp_dir, 'ckpt_'),
Expand All @@ -332,15 +336,15 @@ def test_jax_array(self, jax_array_config):
test_object1 = {'a': jnp.ones(3), 'b': jnp.arange(3, 6)}
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, test_object0)
check_eq(new_object, test_object0)
# Create leftover temporary checkpoint, which should be ignored.
gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
checkpoints.save_checkpoint(
tmp_dir, test_object1, 0, prefix='test_', keep=1)
self.assertIn('test_0', os.listdir(tmp_dir))
new_object = checkpoints.restore_checkpoint(
tmp_dir, test_object0, prefix='test_')
jtu.check_eq(new_object, {'a': np.ones(3), 'b': np.arange(3, 6)})
check_eq(new_object, {'a': np.ones(3), 'b': np.arange(3, 6)})


def test_convert_pre_linen(self):
Expand Down
16 changes: 10 additions & 6 deletions tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import jax
from jax import random
from jax._src import test_util as jtu
from jax.nn import initializers
import jax.numpy as jnp

Expand All @@ -32,6 +31,11 @@
jax.config.parse_flags_with_absl()


def check_eq(xs, ys):
return jax.tree_util.tree_all(
jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys))


class PoolTest(parameterized.TestCase):

def test_pool_custom_reduce(self):
Expand Down Expand Up @@ -80,9 +84,9 @@ def test_avg_pool_padding_same(self, count_include_pad):
pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
y = pool(x)
if count_include_pad:
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
else:
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
np.testing.assert_allclose(y, expected_y)

def test_max_pool(self):
Expand Down Expand Up @@ -110,9 +114,9 @@ def test_avg_pool_padding_same(self, count_include_pad):
pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
y = pool(x)
if count_include_pad:
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
else:
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
np.testing.assert_allclose(y, expected_y)


Expand Down Expand Up @@ -391,7 +395,7 @@ def test_optimized_lstm_cell_matches_regular(self):
(_, y_opt), lstm_opt_params = lstm_opt.init_with_output(key2, (c0, h0), x)

np.testing.assert_allclose(y, y_opt, rtol=1e-6)
jtu.check_eq(lstm_params, lstm_opt_params)
check_eq(lstm_params, lstm_opt_params)


class IdsTest(absltest.TestCase):
Expand Down

0 comments on commit e0de630

Please sign in to comment.