
Leave out serialization for a different PR
carmocca committed Jun 1, 2021
1 parent d7823e3 commit 37e7589
Showing 2 changed files with 0 additions and 111 deletions.
34 changes: 0 additions & 34 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -129,12 +129,6 @@ def __repr__(self) -> str:
state += f", cumulated_batch_size={self.cumulated_batch_size}"
return f"{self.__class__.__name__}({state})"

# FIXME: necessary? tests pass?
# def __getstate__(self) -> dict:
# d = super().__getstate__()
# # delete reference to TorchMetrics Metric
# d['_modules'].pop('value', None)
# return d
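The commented-out `__getstate__` above captures a common pickling pattern: drop the reference to a live `torchmetrics.Metric` submodule before serializing, then re-attach it after loading (which is what `load_from_state_dict` below does via `meta.metric_attribute`). A minimal sketch of that pattern, with illustrative names rather than the actual `ResultMetric` API:

```python
from torch import nn
from torchmetrics import Metric


class ResultLike(nn.Module):
    """Illustrative holder for a torchmetrics.Metric (not the real ResultMetric)."""

    def __init__(self, value: Metric) -> None:
        super().__init__()
        self.value = value  # registered as a submodule under self._modules['value']

    def __getstate__(self) -> dict:
        d = self.__dict__.copy()
        # shallow-copy the submodule registry so the live object is untouched,
        # then drop the Metric reference; the caller re-attaches it after loading
        d["_modules"] = d["_modules"].copy()
        d["_modules"].pop("value", None)
        return d
```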


class _SerializationHelper(dict):
@@ -494,34 +488,6 @@ def _extract_batch_size(self, batch: Any) -> int:
size = 1
return size

def state_dict(self):

def to_state_dict(item: ResultMetric) -> _SerializationHelper:
return _SerializationHelper(**item.__getstate__())

return {k: apply_to_collection(v, ResultMetric, to_state_dict) for k, v in self.items()}

def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Optional[Dict[str, Metric]] = None) -> None:

def to_result_metric(item: _SerializationHelper) -> ResultMetric:
result_metric = ResultMetric(item["meta"], item["is_tensor"])
result_metric.__dict__.update(item)
return result_metric.to(self.device)

state_dict = {k: apply_to_collection(v, _SerializationHelper, to_result_metric) for k, v in state_dict.items()}
for k, v in state_dict.items():
self[k] = v

if metrics:

def re_assign_metric(item: ResultMetric) -> None:
# metric references are lost during serialization and need to be set back during loading
name = item.meta.metric_attribute
if isinstance(name, str) and name in metrics:
item.value = metrics[name]

apply_to_collection(self, ResultMetric, re_assign_metric)

def __str__(self) -> str:
return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})'
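The removed `state_dict` and `load_from_state_dict` above lean on `apply_to_collection` to map a conversion function over arbitrarily nested containers of `ResultMetric` while preserving their structure. A small standalone illustration of that utility with plain tensors (the data is made up purely for demonstration):

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

nested = {"a": torch.tensor(1.0), "b": [torch.tensor(2.0), torch.tensor(3.0)]}

# every element matching the given dtype is transformed; dicts and lists keep their shape
doubled = apply_to_collection(nested, torch.Tensor, lambda t: t * 2)
print(doubled)  # {'a': tensor(2.), 'b': [tensor(4.), tensor(6.)]}
```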

77 changes: 0 additions & 77 deletions tests/core/test_metric_result_integration.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
@@ -169,81 +167,6 @@ def test_result_metric_integration():
)


def test_result_collection_restoration():

_result = None
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()

result = ResultCollection(True, torch.device("cpu"))

for _ in range(2):

result.on_epoch_end_reached = False
cumulative_sum = 0

for i in range(3):

result.batch_idx = i

a = metric_a(i)
b = metric_b(i)
c = metric_c(i)

cumulative_sum += i

result.log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a")
result.log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
result.log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
result.log('training_step', 'a_1', a, on_step=True, on_epoch=True)
result.log('training_step', 'b_1', b, on_step=False, on_epoch=True)
result.log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False)

batch_log = result.metrics[MetricSource.LOG]
assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
assert set(batch_log['c_1']) == {'1', '2'}

_result = deepcopy(result)
state_dict = result.state_dict()

result = ResultCollection(True, torch.device("cpu"))
result.load_from_state_dict(
state_dict, {
"metric_a": metric_a,
"metric_b": metric_b,
"metric_c": metric_c,
"metric_a_end": metric_a
}
)

assert _result.items() == result.items()

result.on_epoch_end_reached = True
_result.on_epoch_end_reached = True

epoch_log = result.metrics[MetricSource.LOG]
_epoch_log = _result.metrics[MetricSource.LOG]
assert epoch_log == _epoch_log

assert set(epoch_log) == {'a_1_epoch', 'a_epoch', 'b', 'b_1'}
for k in epoch_log:
if k in {'a_epoch', 'b'}:
assert epoch_log[k] == cumulative_sum
else:
assert epoch_log[k] == 1

result.log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end")

_result.reset()
result.reset()

# assert metric state reset to default values
assert metric_a.x == metric_a._defaults['x']
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']
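`DummyMetric` is defined near the top of this test file; for reference, a metric consistent with how the test uses it (a single summed state `x` that `reset()` restores to its default) would look roughly like this sketch, which may differ in detail from the actual helper:

```python
import torch
from torchmetrics import Metric


class DummyMetric(Metric):
    """Sketch of the test helper: accumulates everything passed to update()."""

    def __init__(self):
        super().__init__()
        # the default (0) is what `metric.x` should equal again after reset()
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x
```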


def test_result_collection_simple_loop():

result = ResultCollection(True, torch.device("cpu"))
