Remove numpy from src/lightning/pytorch and use torch only #17278

Merged · 36 commits · Apr 15, 2023
Changes from 17 commits

Commits (36):
3bc4545  Replace numpy with torch in logger.py and document code (ishandutta0098, Apr 4, 2023)
3808fa5  Replace numpy with torch in simple.py (ishandutta0098, Apr 4, 2023)
fd66621  Replace numpy with torch in multiprocessing.py (ishandutta0098, Apr 4, 2023)
8ffdf10  Replace numpy with torch in model_summary.py (ishandutta0098, Apr 4, 2023)
5aad2db  round aggregated_value to 2 decimal places (ishandutta0098, Apr 4, 2023)
c5c72e9  Merge branch 'master' into remove_numpy_pytorch (ishandutta0098, Apr 5, 2023)
571c115  Update the doctest for merge_dicts() (ishandutta0098, Apr 5, 2023)
d11a092  Merge branch 'remove_numpy_pytorch' of github.com:ishandutta0098/ligh… (ishandutta0098, Apr 5, 2023)
5cb95a6  Update doctest for LayerSummary Class in model_summary.py (ishandutta0098, Apr 5, 2023)
14bffc2  Update typehint for default_func in merge_dicts() and remove __main__ (ishandutta0098, Apr 5, 2023)
6453731  Remove __main__ (ishandutta0098, Apr 5, 2023)
22f7d3b  Remove assert statements from merge_dicts() doctest (ishandutta0098, Apr 6, 2023)
5086854  Merge branch 'master' into remove_numpy_pytorch (ishandutta0098, Apr 6, 2023)
c6d90d8  Remove doctest checks (ishandutta0098, Apr 6, 2023)
43a08e9  Update LayerSummary doctest and num_parameters property (ishandutta0098, Apr 6, 2023)
b89404a  Merge branch 'master' into remove_numpy_pytorch (ishandutta0098, Apr 6, 2023)
424428b  Merge branch 'master' into remove_numpy_pytorch (ishandutta0098, Apr 7, 2023)
fd64560  Revert changes in src/lightning/pytorch/strategies/launchers/multipro… (ishandutta0098, Apr 7, 2023)
ae53048  Use Python functions for min, max and sum in merge_dicts() (ishandutta0098, Apr 7, 2023)
11406a1  Replace torch with math in model_summary.py (ishandutta0098, Apr 7, 2023)
df3e86e  Replace torch with math (ishandutta0098, Apr 7, 2023)
73a467d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 7, 2023)
8d1096e  Remove additional comments from merge_dicts() (ishandutta0098, Apr 7, 2023)
b75db82  Merge branch 'remove_numpy_pytorch' of github.com:ishandutta0098/ligh… (ishandutta0098, Apr 7, 2023)
5a39da8  import math in model_summary.py (ishandutta0098, Apr 7, 2023)
6afd390  Remove manual int() in num_parameters() (ishandutta0098, Apr 7, 2023)
ada4b8e  Simplify implementation of merge_dicts() without torch and restore or… (ishandutta0098, Apr 7, 2023)
3d2f238  Restore original merge_dicts() (ishandutta0098, Apr 7, 2023)
3b0c083  Revert changes for logger.py (ishandutta0098, Apr 11, 2023)
d32e4e7  Change list comprehension to loops for _make_report_extended and _mak… (ishandutta0098, Apr 11, 2023)
9f3bd78  Merge branch 'master' into remove_numpy_pytorch (ishandutta0098, Apr 11, 2023)
31a1eea  Merge branch 'master' into remove_numpy_pytorch (ishandutta0098, Apr 12, 2023)
c3bde08  Remove separate mean calculation for _make_report_extended and _make_… (ishandutta0098, Apr 12, 2023)
9d6695e  Update src/lightning/pytorch/utilities/model_summary/model_summary.py (carmocca, Apr 12, 2023)
d02532a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2023)
92ef081  Merge branch 'master' into remove_numpy_pytorch (Borda, Apr 14, 2023)
45 changes: 35 additions & 10 deletions src/lightning/pytorch/loggers/logger.py
@@ -20,7 +20,7 @@
from collections import defaultdict
from typing import Any, Callable, Dict, Mapping, Optional, Sequence

import numpy as np
import torch

from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility
@@ -94,7 +94,7 @@ def method(*args: Any, **kwargs: Any) -> None:
def merge_dicts(  # pragma: no cover
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping] = None,
    default_func: Callable[[Sequence[float]], float] = np.mean,
    default_func: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> Dict:
    """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given
    function.
@@ -120,25 +120,50 @@
        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
        >>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
        >>> dflt_func = min
        >>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
        >>> dflt_func = torch.min
        >>> agg_funcs = {'a': torch.mean, 'v': torch.max, 'd': {'d1': torch.sum}}
        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
        {'a': 1.3,
         'b': 2.0,
         'c': 1,
         'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
         'v': 2.3}
        {'a': tensor(1.3000),
         'b': tensor(2.),
         'c': tensor(1.),
         'd': {'d1': tensor(3.),
               'd2': tensor(3.),
               'd3': tensor(3.),
               'd4': {'d5': tensor(1.)}},
         'v': tensor(2.3000)}
"""
    # If agg_key_funcs is not provided, initialize it as an empty dictionary
    agg_key_funcs = agg_key_funcs or {}

    # Collect all unique keys from the input dictionaries
    keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))

    # Initialize the output dictionary using defaultdict
    d_out: Dict = defaultdict(dict)

    # Iterate over all unique keys
    for k in keys:
        # Get the aggregation function for the current key, if available
        fn = agg_key_funcs.get(k)

        # Collect values associated with the current key from all input dictionaries
        values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None]

        # Check if the values to aggregate are dictionaries
        if isinstance(values_to_agg[0], dict):
            # Call the merge_dicts function recursively for nested dictionaries
            d_out[k] = merge_dicts(values_to_agg, fn, default_func)

        else:
            d_out[k] = (fn or default_func)(values_to_agg)
            # Convert values_to_agg to a tensor with float32 data type
            values_to_agg_tensor = torch.tensor(values_to_agg, dtype=torch.float32)

            # Apply the aggregation function (fn) or the default function (default_func) to the tensor
            aggregated_value = (fn or default_func)(values_to_agg_tensor)

            # Assign the aggregated value to the output dictionary
            d_out[k] = aggregated_value

    # Convert the defaultdict to a regular dictionary and return it
    return dict(d_out)
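The behavioral change above is that merge_dicts() now packs the collected scalars into a float32 tensor before aggregating, so merged leaves come back as 0-dim tensors rather than plain floats (hence the doctest now prints tensor(...) values). A minimal standalone sketch of just that step, on toy values rather than the library API:

import torch

# Values gathered for one key across the input dicts (toy data)
values_to_agg = [1.7, 1.1, 1.1]
values_tensor = torch.tensor(values_to_agg, dtype=torch.float32)

print(torch.mean(values_tensor))  # tensor(1.3000); np.mean returned the plain float 1.3
print(torch.min(values_tensor))   # tensor(1.1000); the builtin min returned 1.1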
15 changes: 12 additions & 3 deletions src/lightning/pytorch/profilers/simple.py
@@ -19,7 +19,7 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from lightning.pytorch.profilers.profiler import Profiler

@@ -80,15 +80,24 @@ def stop(self, action_name: str) -> None:
    def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]:
        total_duration = time.monotonic() - self.start_time
        report = [
            (a, np.mean(d), len(d), np.sum(d), 100.0 * np.sum(d) / total_duration)
            (
                a,
                torch.mean(torch.tensor(d)).item(),
                len(d),
                torch.sum(torch.tensor(d)).item(),
                100.0 * torch.sum(torch.tensor(d)).item() / total_duration,
            )
            for a, d in self.recorded_durations.items()
        ]
        report.sort(key=lambda x: x[4], reverse=True)
        total_calls = sum(x[2] for x in report)
        return report, total_calls, total_duration

    def _make_report(self) -> _TABLE_DATA:
        report = [(action, np.mean(d), np.sum(d)) for action, d in self.recorded_durations.items()]
        report = [
            (action, torch.mean(torch.tensor(d)).item(), torch.sum(torch.tensor(d)).item())
            for action, d in self.recorded_durations.items()
        ]
        report.sort(key=lambda x: x[1], reverse=True)
        return report

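One readability note on the hunk above: the extended report converts the same duration list with torch.tensor(d) up to three times per action. A sketch of the single-conversion form, on toy data and assumed rather than taken from the merged code (later commits d32e4e7 and c3bde08 rework these comprehensions along these lines):

import torch

recorded_durations = {"training_step": [0.10, 0.12, 0.11]}  # toy data
total_duration = 10.0  # toy wall-clock total

report = []
for action, d in recorded_durations.items():
    d_tensor = torch.tensor(d)  # convert once, reuse for both mean and sum
    d_sum = torch.sum(d_tensor).item()
    report.append((action, torch.mean(d_tensor).item(), len(d), d_sum, 100.0 * d_sum / total_duration))
report.sort(key=lambda x: x[4], reverse=True)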
8 changes: 2 additions & 6 deletions src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -19,7 +19,6 @@
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional

import numpy as np
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
@@ -210,7 +209,6 @@ def _check_torchdistx_support(self) -> None:

    def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
        """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process.
        To avoid issues with memory sharing, we cast the data to numpy.

        Args:
            trainer: reference to the Trainer.
@@ -219,9 +217,7 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:

        Returns:
            A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
            process this output.
        """
        callback_metrics: dict = apply_to_collection(
            trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
        )  # send as numpy to avoid issues with memory sharing
        callback_metrics: dict = apply_to_collection(trainer.callback_metrics, Tensor, lambda x: x.cpu())
        return {"callback_metrics": callback_metrics}

    def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
@@ -235,7 +231,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
trainer.callback_metrics.update(callback_metrics)

def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
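The numpy round trip removed here existed only to decouple returned metrics from worker device memory; a plain .cpu() call achieves the same thing, so the main process can consume the tensors directly. A toy before/after sketch (the apply_to_collection import path is an assumption here; the file above already has it imported):

import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection  # import path assumed

callback_metrics = {"val_loss": torch.tensor(0.25), "train_acc": torch.tensor(0.9)}  # toy metrics

# Before: Tensor -> numpy in the worker, then numpy -> Tensor again in the main process.
# After: a single .cpu() in the worker; update_main_process_results() updates directly.
to_send = apply_to_collection(callback_metrics, Tensor, lambda x: x.cpu())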
14 changes: 9 additions & 5 deletions src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -17,7 +17,6 @@
from collections import OrderedDict
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
@@ -120,8 +119,13 @@ def layer_type(self) -> str:
    @property
    def num_parameters(self) -> int:
        """Returns the number of parameters in this module."""
        return sum(
            cast(int, np.prod(p.shape)) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()
        return int(
            sum(
                cast(int, torch.prod(torch.tensor(p.shape, dtype=torch.float32)))
                if not _is_lazy_weight_tensor(p)
                else 0
                for p in self._module.parameters()
            )
        )
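Routing the shape product through a float32 tensor is roundabout and can lose precision for very large dimensions; later commits in this PR (11406a1, df3e86e, 5a39da8) move this file to math instead. A sketch of that integer-only direction, where count_parameters is a hypothetical helper (the _is_lazy_weight_tensor guard is omitted for brevity):

import math

import torch.nn as nn

def count_parameters(module: nn.Module) -> int:  # hypothetical helper, not the class property
    # math.prod over p.shape stays in exact integer arithmetic
    return sum(math.prod(p.shape) for p in module.parameters())

print(count_parameters(nn.Linear(10, 5)))  # 55: a 5x10 weight plus 5 biases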


@@ -392,8 +396,8 @@ def get_human_readable_count(number: int) -> str:
"""
assert number >= 0
labels = PARAMETER_NUM_UNITS
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
num_groups = int(np.ceil(num_digits / 3))
num_digits = int(torch.floor(torch.log10(torch.tensor(number, dtype=torch.float32))) + 1 if number > 0 else 1)
num_groups = int(torch.ceil(torch.tensor(num_digits, dtype=torch.float32) / 3))
ishandutta0098 marked this conversation as resolved.
Show resolved Hide resolved
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
number = number * (10**shift)
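The same scalar arithmetic works without constructing tensors at all, which is the direction the later "Replace torch with math" commits take. A sketch under that assumption, with a toy count (PARAMETER_NUM_UNITS is assumed to hold the five units ' ', 'K', 'M', 'B', 'T'):

import math

number = 3_000_000  # toy parameter count
num_digits = int(math.floor(math.log10(number)) + 1) if number > 0 else 1
num_groups = min(math.ceil(num_digits / 3), 5)  # cap at 5 units so we don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
print(f"{number * (10 ** shift):.1f}", num_groups)  # prints: 3.0 3, rendered as "3.0 M"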