TF 2.x: Support for keras to estimator #268
Changes from all commits: 105c995, b70335e, 07473f6, 268743a, 8c245d2, 25083c5, f4fefc4
```diff
@@ -716,6 +716,11 @@ def run(*args, **kwargs):
         # at this point we need all collections to be ready
         # this may not be the case at creation of hook
         # as user's code after hook might add collections
+        self.collection_manager.get(CollectionKeys.WEIGHTS).include(
+            "^weights/.*/((?!bias).)*$"
+        )
+        self.collection_manager.get(CollectionKeys.LOSSES).include(".*loss.*")
+        self.collection_manager.get(CollectionKeys.GRADIENTS).include("^gradient")

         self._prepare_collections()
         self.prepared_collections = True
```

Comment on lines +719 to +723

Contributor: It doesn't seem like we have tested the regex patterns themselves.

Contributor (Author): Copied them from the pt (PyTorch) regexes. Do we have tests for regex?
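One lightweight way to test the patterns themselves would be a plain unit test over sample tensor names, as in the sketch below. The test module and the tensor names are hypothetical (real names depend on the model being traced); this is not part of the PR.

```python
import re

# Patterns from the diff above.
WEIGHTS_PATTERN = r"^weights/.*/((?!bias).)*$"
LOSSES_PATTERN = r".*loss.*"
GRADIENTS_PATTERN = r"^gradient"


def test_default_collection_regexes():
    # Hypothetical tensor names for illustration only.
    assert re.match(WEIGHTS_PATTERN, "weights/dense/kernel:0")
    assert not re.match(WEIGHTS_PATTERN, "weights/dense/bias:0")
    assert re.match(LOSSES_PATTERN, "sparse_categorical_crossentropy_loss")
    assert re.match(GRADIENTS_PATTERN, "gradients/dense/MatMul_grad")
    assert not re.match(GRADIENTS_PATTERN, "weights/dense/kernel:0")
```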
New test file (test_estimator.py, @@ -0,0 +1,61 @@):

```python
# Standard Library
# Third Party
import pytest
import tensorflow.compat.v2 as tf
from tests.zero_code_change.tf_utils import get_estimator, get_input_fns

# First Party
import smdebug.tensorflow as smd


@pytest.mark.parametrize("saveall", [True, False])
def test_estimator(out_dir, tf_eager_mode, saveall):
    """ Works as intended. """
    if tf_eager_mode is False:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.reset_default_graph()
    tf.keras.backend.clear_session()
    mnist_classifier = get_estimator()
    train_input_fn, eval_input_fn = get_input_fns()

    # Train and evaluate
    train_steps, eval_steps = 8, 2
    hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)
    hook.set_mode(mode=smd.modes.TRAIN)
    mnist_classifier.train(input_fn=train_input_fn, steps=train_steps, hooks=[hook])
    hook.set_mode(mode=smd.modes.EVAL)
    mnist_classifier.evaluate(input_fn=eval_input_fn, steps=eval_steps, hooks=[hook])

    # Check that hook created and tensors saved
    trial = smd.create_trial(path=out_dir)
    tnames = trial.tensor_names()
    assert len(trial.steps()) > 0
    if saveall:
        assert len(tnames) >= 301
    else:
        assert len(tnames) == 1


@pytest.mark.parametrize("saveall", [True, False])
def test_linear_classifier(out_dir, tf_eager_mode, saveall):
    """ Works as intended. """
    if tf_eager_mode is False:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.reset_default_graph()
    tf.keras.backend.clear_session()
    train_input_fn, eval_input_fn = get_input_fns()
    x_feature = tf.feature_column.numeric_column("x", shape=(28, 28))
    estimator = tf.estimator.LinearClassifier(
        feature_columns=[x_feature], model_dir="/tmp/mnist_linear_classifier", n_classes=10
    )
    hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)
    estimator.train(input_fn=train_input_fn, steps=10, hooks=[hook])

    # Check that hook created and tensors saved
    trial = smd.create_trial(path=out_dir)
    tnames = trial.tensor_names()
    assert len(trial.steps()) > 0
    if saveall:
        assert len(tnames) >= 224
    else:
        assert len(tnames) == 2
```
Changes to tests/zero_code_change/tf_utils.py:

```diff
@@ -5,7 +5,6 @@
 import numpy as np
 import tensorflow.compat.v1 as tf
 import tensorflow_datasets as tfds
-from tensorflow.examples.tutorials.mnist import input_data

 tfds.disable_progress_bar()

```

```diff
@@ -232,6 +231,8 @@ def neural_net(x):


 def get_data() -> "tf.contrib.learn.python.learn.datasets.base.Datasets":
+    from tensorflow.examples.tutorials.mnist import input_data
+
     mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
     return mnist

```

Comment on lines +234 to +235

Contributor: Why?

Contributor (Author): It's causing an error with TF 2 because I'm importing tf_utils in test_estimator.py.

Contributor: Is there a function in this file that is exclusively used with TF 2?

Contributor (Author): There's nothing specific to TF 2; whatever is used in the TF 2 tests works for both TF 1 and TF 2. Only this get_data function is specific to TF 1.
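The underlying issue is when the import statement executes: a module-level import runs as soon as tf_utils is imported (which test_estimator.py now does under TF 2.x, where tensorflow.examples.tutorials is no longer available), while a function-local import only runs if that function is called. A minimal sketch of the pattern, with illustrative module and function names:

```python
# tf1_helpers.py (illustrative module, not part of this PR)

# Module-level import: executes on `import tf1_helpers`, so it would break any
# TF 2.x test that merely imports this module, even without calling get_data().
# from tensorflow.examples.tutorials.mnist import input_data


def get_data():
    # Function-local import: executes only when a TF 1.x test actually calls
    # get_data(), so importing this module stays safe under TF 2.x.
    from tensorflow.examples.tutorials.mnist import input_data

    return input_data.read_data_sets("/tmp/data/", one_hot=True)
```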
Contributor: Why weren't the loss and other regexes set for SessionHook?
Contributor (Author): Because SessionHook doesn't identify these tensors based on regex patterns; they come from the global collections maintained by TF (unlike gradients, which we add manually).
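For context, a minimal sketch of what "global collections" means here, assuming plain TF 1.x-style graph mode (the smdebug internals are not shown):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=(None, 10))
labels = tf.placeholder(tf.int64, shape=(None,))
logits = tf.layers.dense(x, 3)

# tf.losses.* ops register themselves in the graph's LOSSES collection,
# so a session-based hook can look them up without any regex matching.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
print(tf.get_collection(tf.GraphKeys.LOSSES))  # list containing the loss tensor
```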
Contributor: Why are the above being removed, by the way?
Contributor (Author): If I have the loss regex here, then https://github.com/awslabs/sagemaker-debugger/blob/master/smdebug/tensorflow/session.py#L206 starts adding other operations in the graph that match the regex "loss" into the loss collection; it gave 105 loss tensors in the trial instead of 1. Since these regex patterns were added for gradient tape support, I moved them into the grad-tape-related function so that they don't affect other hooks.
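To illustrate the over-matching: the op names below are made up, but a compiled graph typically contains many ops whose names embed the loss scope, and the broad pattern picks up all of them.

```python
import re

# Only the first name is the scalar loss we actually want to save, but the
# broad ".*loss.*" pattern matches every op whose name mentions the loss scope.
op_names = [
    "sparse_softmax_cross_entropy_loss/value",
    "sparse_softmax_cross_entropy_loss/Sum",
    "gradients/sparse_softmax_cross_entropy_loss/value_grad/Reshape",
]
pattern = re.compile(".*loss.*")
print([name for name in op_names if pattern.match(name)])  # all three match
```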