From 2c42494a99315535abac7f893353db0d2c55e0c8 Mon Sep 17 00:00:00 2001 From: Rahul Huilgol Date: Mon, 25 Nov 2019 15:24:45 -0800 Subject: [PATCH 1/4] Make trial tensors method do intersection update --- smdebug/trials/trial.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/smdebug/trials/trial.py b/smdebug/trials/trial.py index 4487c4a13..1ea2fbb9b 100644 --- a/smdebug/trials/trial.py +++ b/smdebug/trials/trial.py @@ -338,14 +338,13 @@ def tensors(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collection=Non return sorted(list(ts)) else: xs = set() - if regex is not None: - xs.update(self._tensors_matching_regex(regex)) if collection is not None: collection_tensors_saved = set(self._tensors.keys()).intersection( self._tensors_in_collection(collection) ) xs.update(collection_tensors_saved) - + if regex is not None: + xs.intersection_update(self._tensors_matching_regex(regex)) return sorted(list(ts.intersection(xs))) def _tensors_for_step(self, step, mode=ModeKeys.GLOBAL) -> list: From 43aad05eb99b7e91fcae8d15b94c17a3abe0f0e2 Mon Sep 17 00:00:00 2001 From: Rahul Huilgol Date: Tue, 26 Nov 2019 12:20:41 -0800 Subject: [PATCH 2/4] Disallow both regex and collection --- smdebug/trials/trial.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/smdebug/trials/trial.py b/smdebug/trials/trial.py index 1ea2fbb9b..81ac2f032 100644 --- a/smdebug/trials/trial.py +++ b/smdebug/trials/trial.py @@ -336,16 +336,14 @@ def tensors(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collection=Non if regex is None and collection is None: return sorted(list(ts)) + elif regex is not None and collection is not None: + raise ValueError("Only one of `regex` or `collection` can be passed " "to this method") else: - xs = set() if collection is not None: - collection_tensors_saved = set(self._tensors.keys()).intersection( - self._tensors_in_collection(collection) - ) - xs.update(collection_tensors_saved) - if regex is not None: - xs.intersection_update(self._tensors_matching_regex(regex)) - return sorted(list(ts.intersection(xs))) + xs = set(self._tensors.keys()).intersection(self._tensors_in_collection(collection)) + else: + xs = self._tensors_matching_regex(regex) + return sorted(list(ts.intersection(xs))) def _tensors_for_step(self, step, mode=ModeKeys.GLOBAL) -> list: step = self._mode_to_global[mode][step] if mode != ModeKeys.GLOBAL else step From d8e28af012fd195ff47fb8ca5dc6787b7e304e72 Mon Sep 17 00:00:00 2001 From: Rahul Huilgol Date: Tue, 26 Nov 2019 12:25:48 -0800 Subject: [PATCH 3/4] add test --- tests/analysis/trials/test_tensors_api.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/analysis/trials/test_tensors_api.py b/tests/analysis/trials/test_tensors_api.py index 49e3f1ed7..df4f24289 100644 --- a/tests/analysis/trials/test_tensors_api.py +++ b/tests/analysis/trials/test_tensors_api.py @@ -48,6 +48,12 @@ def test_tensors(out_dir): assert len(tr.tensors(collection="test")) == num_tensors + 2 assert len(tr.tensors(collection=tr.collection("test"))) == num_tensors + 2 + try: + tr.tensors(collection=tr.collection("test"), regex="a") + assert False + except ValueError: + pass + def test_mode_data(): run_id = "trial_" + datetime.now().strftime("%Y%m%d-%H%M%S%f") From 1a7f603abd34dd508c3564f2b69bee7b75fd09d6 Mon Sep 17 00:00:00 2001 From: Rahul Huilgol Date: Tue, 26 Nov 2019 13:36:56 -0800 Subject: [PATCH 4/4] Update trial.py --- smdebug/trials/trial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smdebug/trials/trial.py b/smdebug/trials/trial.py index 81ac2f032..892da8dee 100644 --- a/smdebug/trials/trial.py +++ b/smdebug/trials/trial.py @@ -337,7 +337,7 @@ def tensors(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collection=Non if regex is None and collection is None: return sorted(list(ts)) elif regex is not None and collection is not None: - raise ValueError("Only one of `regex` or `collection` can be passed " "to this method") + raise ValueError("Only one of `regex` or `collection` can be passed to this method") else: if collection is not None: xs = set(self._tensors.keys()).intersection(self._tensors_in_collection(collection))