Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and carmocca committed Jun 1, 2021
1 parent 89149ac commit d7823e3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
20 changes: 7 additions & 13 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

Expand Down Expand Up @@ -387,20 +387,14 @@ def get_metrics(self, on_step: bool) -> Dict[str, Dict[str, torch.Tensor]]:
value = apply_to_collection(result_metric, ResultMetric, fn, include_none=False)

# detect if the value is None. This can be nested.
is_empty = True
is_none = False

def is_empty_fn(v):
nonlocal is_empty
# update is_empty if any value is not None.
if v is not None:
is_empty = False
def any_none(_):
nonlocal is_none
is_none = True

# apply detection.
# TODO(@tchaton): need to find a way to support NamedTuple
apply_to_collection(value, object, is_empty_fn, wrong_dtype=(Mapping, Sequence))

# skip is the value was actually empty.
if is_empty:
apply_to_collection(value, type(None), any_none)
if is_none:
continue

# extract metadata
Expand Down
12 changes: 3 additions & 9 deletions tests/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,8 @@ def test_result_collection_restoration():
cumulative_sum += i

result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a")
result.log(
'training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b"
)
result.log(
'training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c"
)
result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
result.log('training_step', 'a_1', a, on_step=True, on_epoch=True)
result.log('training_step', 'b_1', b, on_step=False, on_epoch=True)
result.log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False)
Expand Down Expand Up @@ -237,9 +233,7 @@ def test_result_collection_restoration():
else:
assert epoch_log[k] == 1

result.log(
'train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end"
)
result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end")

_result.reset()
result.reset()
Expand Down

0 comments on commit d7823e3

Please sign in to comment.