Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
204 commits
Select commit Hold shift + click to select a range
be4f48a
save outputs
NihalHarish May 29, 2020
d32d017
assert updates
NihalHarish Jun 3, 2020
8e95f12
update assert
NihalHarish Jun 3, 2020
48f45d6
cleanup
NihalHarish Jun 3, 2020
55f10d4
as_dtype:
NihalHarish Jun 3, 2020
ec82021
model outputs are now constants
NihalHarish Jun 4, 2020
666bcd4
update to test
NihalHarish Jun 4, 2020
d867a9b
update import statement
NihalHarish Jun 4, 2020
5fd3a74
tmp
NihalHarish Jun 5, 2020
11c20c6
Revert "tmp"
NihalHarish Jun 8, 2020
7f260e5
str_to_mode
NihalHarish Jun 8, 2020
345f785
add tensor
NihalHarish Jun 8, 2020
beaa68d
add tensor
NihalHarish Jun 8, 2020
ab3d5c1
add dist tensor:
NihalHarish Jun 8, 2020
61372e8
add tensor
NihalHarish Jun 8, 2020
46c5e0f
for-loop
NihalHarish Jun 8, 2020
650fd6a
fix append
NihalHarish Jun 8, 2020
42fdc3a
fix assert
NihalHarish Jun 8, 2020
16b38d1
add
NihalHarish Jun 8, 2020
9e1d2c5
model output
NihalHarish Jun 8, 2020
14d911b
rename
NihalHarish Jun 8, 2020
20d0413
add to all collections
NihalHarish Jun 9, 2020
d46ebb6
revert
NihalHarish Jun 9, 2020
960d383
add to all
NihalHarish Jun 9, 2020
67f4efc
helper fn
NihalHarish Jun 9, 2020
2df341e
helper fn
NihalHarish Jun 9, 2020
94765d2
extend returns none
NihalHarish Jun 9, 2020
9eff79b
ypred
NihalHarish Jun 9, 2020
61d94e1
ypred
NihalHarish Jun 9, 2020
07d72d3
change assert
NihalHarish Jun 9, 2020
d8a8ea9
init
NihalHarish Jun 10, 2020
f7ead88
do not match in metric
NihalHarish Jun 10, 2020
6e24ca8
update
NihalHarish Jun 10, 2020
cda4e3e
inputs
NihalHarish Jun 10, 2020
9b59d0d
id
NihalHarish Jun 10, 2020
9e5606e
save outputs
NihalHarish May 29, 2020
11ddcdd
assert updates
NihalHarish Jun 3, 2020
34d2294
update assert
NihalHarish Jun 3, 2020
f87ce01
cleanup
NihalHarish Jun 3, 2020
bbb0dc6
as_dtype:
NihalHarish Jun 3, 2020
82f0531
model outputs are now constants
NihalHarish Jun 4, 2020
4663370
update to test
NihalHarish Jun 4, 2020
c64a7a1
update import statement
NihalHarish Jun 4, 2020
15c1d61
tmp
NihalHarish Jun 5, 2020
be6186f
Revert "tmp"
NihalHarish Jun 8, 2020
ae8f96b
str_to_mode
NihalHarish Jun 8, 2020
30bd425
add tensor
NihalHarish Jun 8, 2020
1e7aa1b
add tensor
NihalHarish Jun 8, 2020
85ea95a
add dist tensor:
NihalHarish Jun 8, 2020
95b8bcc
add tensor
NihalHarish Jun 8, 2020
07fd399
for-loop
NihalHarish Jun 8, 2020
7151978
fix append
NihalHarish Jun 8, 2020
72a7256
fix assert
NihalHarish Jun 8, 2020
046d165
add
NihalHarish Jun 8, 2020
070cd6f
model output
NihalHarish Jun 8, 2020
8af4ce8
rename
NihalHarish Jun 8, 2020
1761ca2
add to all collections
NihalHarish Jun 9, 2020
6b581bf
revert
NihalHarish Jun 9, 2020
6b14ee7
add to all
NihalHarish Jun 9, 2020
5c89dff
helper fn
NihalHarish Jun 9, 2020
cc13566
helper fn
NihalHarish Jun 9, 2020
d07dd47
extend returns none
NihalHarish Jun 9, 2020
766902a
ypred
NihalHarish Jun 9, 2020
4e1b802
ypred
NihalHarish Jun 9, 2020
5782846
change assert
NihalHarish Jun 9, 2020
f745186
Merge branch 'y_pred' of https://github.com/awslabs/sagemaker-debugge…
NihalHarish Jun 13, 2020
07c6e75
init
NihalHarish Jun 10, 2020
0d8c6cb
do not match in metric
NihalHarish Jun 10, 2020
ae526c0
update
NihalHarish Jun 10, 2020
bf82f9c
inputs
NihalHarish Jun 10, 2020
101fcb2
id
NihalHarish Jun 10, 2020
cdaf7f8
Merge branch 'save_model_inputs' of https://github.com/awslabs/sagema…
NihalHarish Jun 13, 2020
bc84269
test
NihalHarish Jun 13, 2020
5091415
fuse model inputs and outputs
NihalHarish Jun 14, 2020
13ce988
set fix
NihalHarish Jun 15, 2020
460e0e0
add tests
NihalHarish Jun 15, 2020
c20cc75
update test
NihalHarish Jun 15, 2020
5766aa2
eager mode
NihalHarish Jun 15, 2020
0428d62
update tests
NihalHarish Jun 15, 2020
54ad7a5
rename fn
NihalHarish Jun 15, 2020
40ded77
remove unused imports
NihalHarish Jun 15, 2020
9ead6fa
save custom tensor fn
NihalHarish Jun 16, 2020
c9a6198
test_
NihalHarish Jun 16, 2020
7c7fbb3
revert tests
NihalHarish Jun 16, 2020
ab8d103
save custom tensor fn
NihalHarish Jun 16, 2020
63babf7
test_
NihalHarish Jun 16, 2020
9633e2e
save custom tensor
NihalHarish Jun 16, 2020
a997bfa
save custom tensor
NihalHarish Jun 16, 2020
1376045
init
NihalHarish Jun 17, 2020
05b28c5
save gradients
NihalHarish Jun 17, 2020
9ae86df
ignore smdebug metrics
NihalHarish Jun 17, 2020
c8a0844
update assert
NihalHarish Jun 17, 2020
3db6856
gradients
NihalHarish Jun 17, 2020
32affd2
save inputs
NihalHarish Jun 19, 2020
582cd6e
merge master
NihalHarish Jun 26, 2020
ccde310
checks
NihalHarish Jun 26, 2020
4e14182
change assert
NihalHarish Jun 26, 2020
a68dc3e
check if collection should be saved
NihalHarish Jun 26, 2020
712f94b
set
NihalHarish Jun 26, 2020
cdb0882
revert assert
NihalHarish Jun 26, 2020
c692d8f
revert assert
NihalHarish Jun 26, 2020
cac439d
save inputs
NihalHarish Jun 26, 2020
cd36430
change regex
NihalHarish Jun 26, 2020
60d671b
modify tests
NihalHarish Jun 26, 2020
73b5362
collection
NihalHarish Jun 26, 2020
abdc64b
save fn
NihalHarish Jun 26, 2020
027b022
move test
NihalHarish Jun 26, 2020
6c5e4c9
run only for tf2
NihalHarish Jun 26, 2020
29e1319
mark skip
NihalHarish Jun 26, 2020
9e9092b
fn rename
NihalHarish Jun 26, 2020
e97de64
rename fn
NihalHarish Jun 26, 2020
cec3e09
correct boolean logic
NihalHarish Jun 26, 2020
90a8f23
fix input output logic
NihalHarish Jun 26, 2020
06ebf84
comments
NihalHarish Jun 26, 2020
15851de
grad tape example
NihalHarish Jun 26, 2020
41ca695
save layers
NihalHarish Jun 26, 2020
af1e411
rename
NihalHarish Jun 26, 2020
8cdd13e
change boolean logic
NihalHarish Jun 26, 2020
03e4f18
bug fix
NihalHarish Jun 26, 2020
2660a76
retrigger CI
NihalHarish Jun 29, 2020
fccf7e8
fix flag
NihalHarish Jul 2, 2020
f221f74
duplicate set
NihalHarish Jul 3, 2020
480db00
pred
NihalHarish Jul 7, 2020
c0817b9
nit
NihalHarish Jul 8, 2020
cb79e19
Merge remote-tracking branch 'origin' into save_inputs
NihalHarish Jul 8, 2020
80a65c7
update
NihalHarish Jul 10, 2020
e7cb92a
rename default collection
NihalHarish Jul 15, 2020
39b65df
model inputs
NihalHarish Jul 15, 2020
ca68f77
lint
NihalHarish Jul 15, 2020
281011d
update tests
NihalHarish Jul 15, 2020
74de9c9
modify assert
NihalHarish Jul 15, 2020
6dd95d7
Merge remote-tracking branch 'origin' into save_inputs
NihalHarish Jul 15, 2020
9abe494
modify assert
NihalHarish Jul 16, 2020
33c21c0
save Layers
NihalHarish Jul 16, 2020
7bd87c8
clear saved collections after saving
NihalHarish Jul 17, 2020
651d6ea
refactor
NihalHarish Jul 17, 2020
1aaabe7
nit
NihalHarish Jul 17, 2020
6d3b733
pr comments
NihalHarish Jul 22, 2020
0f08773
save tensor api
NihalHarish Jul 22, 2020
3015cae
revert typo
NihalHarish Jul 22, 2020
cca7fea
save custom tensors
NihalHarish Jul 22, 2020
8764eb2
asserts
NihalHarish Jul 22, 2020
bbf1bf6
pr comments
NihalHarish Jul 22, 2020
a32a8d4
len
NihalHarish Jul 22, 2020
259414a
default
NihalHarish Jul 23, 2020
b1ad7a0
save smdebug logs
NihalHarish Jul 23, 2020
d3b54c3
comments
NihalHarish Jul 23, 2020
fb548a9
update
NihalHarish Jul 23, 2020
7ca2942
constants
NihalHarish Jul 27, 2020
2df55e0
Implement Save Tensor For Mxnet and Pytorch (#291)
NihalHarish Jul 28, 2020
067e724
parameterize test keras fit
NihalHarish Jul 28, 2020
49550e8
tf eager
NihalHarish Jul 28, 2020
8dfb4a9
asserts
NihalHarish Jul 22, 2020
b64e9e2
Merge branch 'save_inputs_asserts' of https://github.com/awslabs/sage…
NihalHarish Jul 28, 2020
e4c0185
merge
NihalHarish Jul 30, 2020
20e4329
updated asserts
NihalHarish Jul 30, 2020
98378e0
save tensor
NihalHarish Jul 30, 2020
9f4ce1e
load from s3 path
NihalHarish Jul 30, 2020
80b27aa
load from s3 path
NihalHarish Jul 30, 2020
09b914c
constants file
NihalHarish Jul 30, 2020
acf4990
import
NihalHarish Jul 30, 2020
4b16a76
update asserts
NihalHarish Jul 30, 2020
401b27c
update assert
NihalHarish Jul 30, 2020
0fb3d10
Merge branch 'load_dataset_from_s3' into save_inputs_asserts
NihalHarish Jul 30, 2020
4049e86
retrigger CI
NihalHarish Jul 30, 2020
7d8130e
rename variable
NihalHarish Jul 30, 2020
fba5837
check if s3 is accessible
NihalHarish Jul 30, 2020
8e613bf
merge master
NihalHarish Jul 30, 2020
6ee1a7d
Merge branch 'load_dataset_from_s3' of https://github.com/awslabs/sag…
NihalHarish Jul 30, 2020
5407187
Merge branch 'master' into load_dataset_from_s3
NihalHarish Jul 30, 2020
3501448
pythonic
NihalHarish Jul 30, 2020
1098f4b
Merge branch 'load_dataset_from_s3' into save_inputs_asserts
NihalHarish Jul 30, 2020
2816c41
sanitize s3 path
NihalHarish Jul 30, 2020
4fd8fa9
Merge branch 'load_dataset_from_s3' into save_inputs_asserts
NihalHarish Jul 30, 2020
7a8eeec
Merge remote-tracking branch 'origin' into save_inputs_asserts
NihalHarish Jul 31, 2020
08b7d95
Merge remote-tracking branch 'origin' into save_inputs_asserts
NihalHarish Jul 31, 2020
d7b12dc
save layers
NihalHarish Jul 31, 2020
3b5d750
remove json
NihalHarish Jul 31, 2020
1e66117
gradient tape tests
NihalHarish Jul 31, 2020
4e8f56f
nit
NihalHarish Jul 31, 2020
4378f8d
Revert "nit"
NihalHarish Jul 31, 2020
e004d2f
Merge remote-tracking branch 'origin' into zcc_tests_tf_2_2
NihalHarish Aug 1, 2020
b1d52a4
stronger asserts
NihalHarish Aug 1, 2020
499dfff
update json
NihalHarish Aug 2, 2020
0c9d2d7
Merge remote-tracking branch 'origin' into save_inputs_asserts
NihalHarish Aug 2, 2020
2c6ed03
Merge branch 'zcc_tests_tf_2_2' into save_inputs_asserts
NihalHarish Aug 2, 2020
9fc170d
rename enum
NihalHarish Aug 2, 2020
0933d02
change eager mode to run_eagerly
NihalHarish Aug 2, 2020
e0da231
mod assert
NihalHarish Aug 2, 2020
c091632
default
NihalHarish Aug 2, 2020
8f456f5
reduce runtime
NihalHarish Aug 4, 2020
e934b2e
speed
NihalHarish Aug 4, 2020
2a597d5
saveall
NihalHarish Aug 4, 2020
80af190
faster tests
NihalHarish Aug 4, 2020
83056d4
custom train step
NihalHarish Aug 5, 2020
097b563
mark skip
NihalHarish Aug 5, 2020
1ae0fc3
update test json config
NihalHarish Aug 6, 2020
d24fbe6
Merge branch 'master' into save_inputs_asserts
NihalHarish Aug 6, 2020
4db75e8
lint
NihalHarish Aug 6, 2020
f2c7903
update tf22 check
NihalHarish Aug 6, 2020
2a51d1a
support both tf 23 and tf 22
NihalHarish Aug 6, 2020
2a731c2
update integ tests
NihalHarish Aug 6, 2020
990618d
merge integ tests
NihalHarish Aug 6, 2020
5efbbfb
disable tb testing
NihalHarish Aug 6, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions tests/tensorflow2/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tests.constants import TEST_DATASET_S3_PATH
from tests.tensorflow2.utils import is_tf_2_2
from tests.tensorflow2.utils import is_tf_2_2, is_tf_2_3
from tests.tensorflow.utils import create_trial_fast_refresh
from tests.utils import use_s3_datasets

Expand Down Expand Up @@ -195,7 +195,7 @@ def test_keras_gradtape(out_dir, saveall):

trial = smd.create_trial(path=out_dir)
if saveall: # save losses, metrics, weights, biases
assert len(trial.tensor_names()) == 15
assert len(trial.tensor_names()) == (25 if is_tf_2_2() else 15)
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_gradtape_include_regex(out_dir):
tr = create_trial_fast_refresh(out_dir)
tnames = tr.tensor_names(collection="custom_coll")

assert len(tnames) == 8
assert len(tnames) == (12 if is_tf_2_2() else 8)
for tname in tnames:
assert tr.tensor(tname).value(0) is not None

Expand Down Expand Up @@ -343,7 +343,7 @@ def test_gradtape_include_collections(out_dir):

trial = smd.create_trial(path=out_dir)
# can't save gradients in TF 2.x
assert len(trial.tensor_names()) == 15
assert len(trial.tensor_names()) == (16 if is_tf_2_2() else 15)
assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) == 4
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
Expand Down Expand Up @@ -388,7 +388,7 @@ def test_gradtape_persistent(out_dir, saveall):

trial = smd.create_trial(path=out_dir)
if saveall: # save losses, metrics, weights, biases
assert len(trial.tensor_names()) == 15
assert len(trial.tensor_names()) == (25 if is_tf_2_2() else 15)
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
Expand All @@ -409,17 +409,24 @@ def test_keras_fit(out_dir, tf_eager_mode, saveall):
helper_keras_fit(
trial_dir=out_dir,
hook=hook,
eager=tf_eager_mode,
run_eagerly=tf_eager_mode,
steps=["train", "eval", "predict", "train"],
)

trial = smd.create_trial(path=out_dir)
# can't save gradients in TF 2.x eager mode
if saveall: # save losses, metrics, weights, biases, scalar
if tf_eager_mode:
assert len(trial.tensor_names()) == (13 if is_tf_2_2() else 14)
assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == 0
assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == 0
if is_tf_2_2():
assert len(trial.tensor_names()) == 28
else:
assert len(trial.tensor_names()) == (21 if is_tf_2_3() else 14)
assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == (
1 if is_tf_2_2() else 0
)
assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == (
2 if is_tf_2_2() else 0
)
else:
assert len(trial.tensor_names()) == 21
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
Expand All @@ -435,10 +442,12 @@ def test_keras_fit(out_dir, tf_eager_mode, saveall):
"No Optimizer Variables Should be Saved in EVAL Mode",
)
else: # save the default losses and metrics
assert len(trial.tensor_names()) == (4 if is_tf_2_2() and tf_eager_mode else 5)
assert len(trial.tensor_names()) == (
4 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 5
)
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == (
2 if is_tf_2_2() and tf_eager_mode else 3
2 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 3
)
for tname in trial.tensor_names():
assert trial.tensor(tname).value(0) is not None
Expand Down Expand Up @@ -510,7 +519,7 @@ def test_include_regex(out_dir, tf_eager_mode):
tnames = tr.tensor_names(collection="custom_coll")

if tf_eager_mode:
assert len(tnames) == 8
assert len(tnames) == (12 if is_tf_2_2() else 8)
else:
assert len(tnames) == 8
for tname in tnames:
Expand All @@ -534,7 +543,7 @@ def test_clash_with_tb_callback(out_dir):
add_callbacks=["tensorboard"],
)
tr = create_trial_fast_refresh(out_dir)
assert len(tr.tensor_names()) == (7 if is_tf_2_2() else 8)
assert len(tr.tensor_names()) == (7 if (is_tf_2_2() or is_tf_2_3()) else 8)


@pytest.mark.slow
Expand All @@ -560,12 +569,12 @@ def test_weights_collections(out_dir, tf_eager_mode):

trial = smd.create_trial(path=out_dir)
# can't save gradients in TF 2.x
assert len(trial.tensor_names()) == (5 if is_tf_2_2() and tf_eager_mode else 6)
assert len(trial.tensor_names()) == (5 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 6)
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 0
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == (
2 if is_tf_2_2() and tf_eager_mode else 3
2 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 3
)


Expand Down Expand Up @@ -595,7 +604,10 @@ def test_include_collections(out_dir, tf_eager_mode):
trial = smd.create_trial(path=out_dir)
# can't save gradients in TF 2.x
if tf_eager_mode:
assert len(trial.tensor_names()) == (12 if is_tf_2_2() else 13)
if is_tf_2_2():
assert len(trial.tensor_names()) == 16
else:
assert len(trial.tensor_names()) == (12 if is_tf_2_3() else 13)
else:
assert len(trial.tensor_names()) == 18
assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) == 4
Expand All @@ -605,7 +617,7 @@ def test_include_collections(out_dir, tf_eager_mode):
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == (
2 if is_tf_2_2() and tf_eager_mode else 3
2 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 3
)


Expand All @@ -625,7 +637,7 @@ def test_include_only_custom_collection(out_dir, tf_eager_mode):
)

trial = smd.create_trial(path=out_dir)
assert len(trial.tensor_names()) == (8 if is_tf_2_2() and tf_eager_mode else 9)
assert len(trial.tensor_names()) == (8 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 9)
assert len(trial.tensor_names(collection="custom_optimizer_variables")) == 5


Expand All @@ -640,12 +652,12 @@ def test_hook_from_json(out_dir, tf_eager_mode, monkeypatch):

trial = smd.create_trial(path=out_dir)
# can't save gradients in TF 2.x
assert len(trial.tensor_names()) == (5 if is_tf_2_2() and tf_eager_mode else 6)
assert len(trial.tensor_names()) == (5 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 6)
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 0
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == (
2 if is_tf_2_2() and tf_eager_mode else 3
2 if (is_tf_2_2() or is_tf_2_3()) and tf_eager_mode else 3
)


Expand All @@ -658,12 +670,15 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode):
helper_keras_fit(trial_dir=out_dir, hook=hook, eager=tf_eager_mode, run_eagerly=True)

trial = smd.create_trial(path=out_dir)
assert len(trial.tensor_names()) == (20 if is_tf_2_2() else 21)
if is_tf_2_2():
assert len(trial.tensor_names()) == 27
else:
assert len(trial.tensor_names()) == (20 if is_tf_2_3() else 21)
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == 0
assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == 0
assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == (1 if is_tf_2_2() else 0)
assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == (2 if is_tf_2_2() else 0)


@pytest.mark.skip # skip until aws tf update
Expand Down
37 changes: 26 additions & 11 deletions tests/tensorflow2/test_keras_mirrored.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tensorflow_datasets as tfds
from tensorflow.python.client import device_lib
from tests.core.utils import verify_files
from tests.tensorflow2.utils import is_tf_2_2
from tests.tensorflow2.utils import is_tf_2_2, is_tf_2_3
from tests.tensorflow.utils import create_trial_fast_refresh

# First Party
Expand Down Expand Up @@ -164,11 +164,16 @@ def exhaustive_check(trial_dir, include_workers="one", eager=True):
if include_workers == "all":
assert len(tr.workers()) == strategy.num_replicas_in_sync
if eager:
assert len(tr.tensor_names()) == (
6 + 1 + 2 + 5 + 1 if is_tf_2_2() else 6 + 1 + 3 + 5 + 1
)
# 6 weights, 1 loss, 3 metrics, 5 optimizer variables for Tf 2.1, 1 scalar
# 6 weights, 1 loss, 2 metrics, 5 optimizer variables for Tf 2.2, 1 scalar
if is_tf_2_2():
assert len(tr.tensor_names()) == (6 + 1 + 2 + 5 + 1 + 6 + 2)
# 6 weights, 1 loss, 2 metrics, 5 optimizer variables, 6 gradients, 2 outputs for Tf 2.2, 1 scalar
else:
assert len(tr.tensor_names()) == (
6 + 1 + 2 + 5 + 1 if (is_tf_2_2() or is_tf_2_3()) else 6 + 1 + 3 + 5 + 1
)
# 6 weights, 1 loss, 2 metrics, 5 optimizer variables for Tf 2.3, 1 scalar
# 6 weights, 1 loss, 3 metrics, 5 optimizer variables for Tf 2.1, 1 scalar

else:
assert len(tr.tensor_names()) == (6 + 6 + 1 + 3 + strategy.num_replicas_in_sync * 3 + 5)
else:
Expand Down Expand Up @@ -232,7 +237,7 @@ def exhaustive_check(trial_dir, include_workers="one", eager=True):
assert len(tr.tensor(loss_name).steps()) == 12

metricnames = tr.tensor_names(collection=CollectionKeys.METRICS)
assert len(metricnames) == (2 if is_tf_2_2() else 3)
assert len(metricnames) == (2 if (is_tf_2_2() or is_tf_2_3()) else 3)


@pytest.mark.slow
Expand All @@ -256,8 +261,15 @@ def test_save_all(out_dir, tf_eager_mode, workers):
tr = create_trial_fast_refresh(out_dir)
print(tr.tensor_names())
if tf_eager_mode:
assert len(tr.tensor_names()) == (6 + 2 + 1 + 5 + 1 if is_tf_2_2() else 6 + 3 + 1 + 5 + 1)
# weights, metrics, losses, optimizer variables, scalar
if is_tf_2_2():
assert len(tr.tensor_names()) == (
6 + 2 + 1 + 5 + 1 + 1 + 2 + 8 + 8 if is_tf_2_2() else 6 + 3 + 1 + 5 + 1
)
# weights, metrics, losses, optimizer variables, scalar, inputs, outputs, gradients, layers
else:
assert len(tr.tensor_names()) == (
6 + 2 + 1 + 5 + 1 if is_tf_2_3() else 6 + 3 + 1 + 5 + 1
)
else:
assert (
len(tr.tensor_names())
Expand Down Expand Up @@ -366,7 +378,7 @@ def test_include_regex(out_dir, tf_eager_mode, workers):
tnames = tr.tensor_names(collection="custom_coll")

if tf_eager_mode:
assert len(tnames) == 4
assert len(tnames) == (12 if is_tf_2_2() else 4)
else:
assert len(tnames) == 4 + 3 * strategy.num_replicas_in_sync
for tname in tnames:
Expand Down Expand Up @@ -421,7 +433,10 @@ def test_clash_with_tb_callback(out_dir):
add_callbacks=["tensorboard"],
)
tr = create_trial_fast_refresh(out_dir)
assert len(tr.tensor_names()) == (10 if is_tf_2_2() else 11)
if is_tf_2_2():
assert len(tr.tensor_names()) == 16
else:
assert len(tr.tensor_names()) == (10 if is_tf_2_3() else 11)


@pytest.mark.skip
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def is_tf_2_2():
number of tensor_names emitted by 1.
:return: bool
"""
if version.parse(tf.__version__) >= version.parse("2.2.0"):
if version.parse(tf.__version__) == version.parse("2.2.0"):
return True
return False

Expand Down
37 changes: 31 additions & 6 deletions tests/zero_code_change/test_tensorflow2_gradtape_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# Third Party
import pytest
import tensorflow.compat.v2 as tf
from tests.tensorflow2.utils import is_tf_2_2
from tests.tensorflow2.utils import is_tf_2_2, is_tf_2_3

# First Party
import smdebug.tensorflow as smd
Expand All @@ -26,7 +26,9 @@ def get_keras_data():
return (x_train, y_train), (x_test, y_test)


def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_contents="{}"):
def helper_test_keras_v2_gradienttape(
script_mode: bool = False, json_file_contents="{}", default=False
):
""" Test the default ZCC behavior of saving losses and metrics in eager and non-eager modes."""
smd.del_hook()
tf.keras.backend.clear_session()
Expand All @@ -49,7 +51,7 @@ def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_conte
opt = tf.keras.optimizers.RMSprop()
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
n_epochs = 2
n_epochs = 1
if script_mode:
if json_file_contents == "{}":
hook = smd.KerasHook(out_dir=sim.out_dir, export_tensorboard=True)
Expand Down Expand Up @@ -100,7 +102,7 @@ def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_conte
print(log)
train_acc_metric.reset_states()
hook = smd.get_hook()
if not is_tf_2_2():
if not (is_tf_2_2() or is_tf_2_3()):
assert not hook # only supported on TF 2.2 and greater
return
assert hook
Expand All @@ -110,12 +112,23 @@ def helper_test_keras_v2_gradienttape(script_mode: bool = False, json_file_conte
assert len(trial.steps()) > 0, "Nothing saved at any step."
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
assert len(trial.tensor_names(collection="losses")) > 0
if is_tf_2_2() and default is False:
# Inputs and Outputs are not saved with the default collection configurations.
assert len(trial.tensor_names(collection="inputs")) > 0
assert len(trial.tensor_names(collection="outputs")) > 0
assert trial.tensor_names(collection="outputs") == ["predictions"]
if "dense_layers" in json_file_contents:
# Only assert for test_keras_v2_multi_collections
# which defines this custom collection
assert len(trial.tensor_names(collection="dense_layers")) > 0
else:
assert len(trial.tensor_names(collection="dense_layers")) == 0


@pytest.mark.parametrize("script_mode", [False])
def test_keras_v2_default(script_mode):
# Test default ZCC behavior
helper_test_keras_v2_gradienttape(script_mode=script_mode)
helper_test_keras_v2_gradienttape(script_mode=script_mode, default=True)


@pytest.mark.parametrize("script_mode", [False])
Expand Down Expand Up @@ -144,6 +157,18 @@ def test_keras_v2_multi_collections(script_mode):
},
{
"CollectionName": "optimizer_variables"
},
{
"CollectionName": "outputs"
},
{
"CollectionName": "inputs"
},
{
"CollectionName": "dense_layers",
"CollectionParameters": {
"include_regex": ".*dense.*"
}
}
]
}
Expand All @@ -161,7 +186,7 @@ def test_keras_v2_save_all(script_mode):
"S3OutputPath": "s3://sagemaker-test",
"LocalPath": "/opt/ml/output/tensors",
"HookParameters" : {
"save_steps": "0,1,2,3",
"save_steps": "0",
"save_all": true
}
}
Expand Down
Loading