Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,20 @@ class BenchmarkResult:
max_mem_allocated: List[int] # megabytes
rank: int = -1

def runtime_percentile(self, percentile: int = 50) -> torch.Tensor:
def runtime_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
return torch.quantile(
self.elapsed_time, percentile / 100.0, interpolation="nearest"
self.elapsed_time,
percentile / 100.0,
interpolation=interpolation,
)

def max_mem_percentile(self, percentile: int = 50) -> torch.Tensor:
def max_mem_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
max_mem = torch.tensor(self.max_mem_allocated, dtype=torch.float)
return torch.quantile(max_mem, percentile / 100.0, interpolation="nearest")
return torch.quantile(max_mem, percentile / 100.0, interpolation=interpolation)


class ECWrapper(torch.nn.Module):
Expand Down
61 changes: 41 additions & 20 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def generate(
long_indices: bool = True,
tables_pooling: Optional[List[int]] = None,
weighted_tables_pooling: Optional[List[int]] = None,
randomize_indices: bool = True,
device: Optional[torch.device] = None,
) -> Tuple["ModelInput", List["ModelInput"]]:
"""
Returns a global (single-rank training) batch
Expand Down Expand Up @@ -132,15 +134,16 @@ def _validate_pooling_factor(
idlist_pooling_factor[idx],
idlist_pooling_factor[idx] / 10,
[batch_size * world_size],
device=device,
),
torch.tensor(1.0),
torch.tensor(1.0, device=device),
).int()
else:
lengths_ = torch.abs(
torch.randn(batch_size * world_size) + pooling_avg
torch.randn(batch_size * world_size, device=device) + pooling_avg,
).int()
if variable_batch_size:
lengths = torch.zeros(batch_size * world_size).int()
lengths = torch.zeros(batch_size * world_size, device=device).int()
for r in range(world_size):
lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
lengths_[
Expand All @@ -150,12 +153,20 @@ def _validate_pooling_factor(
else:
lengths = lengths_
num_indices = cast(int, torch.sum(lengths).item())
indices = torch.randint(
0,
ind_range,
(num_indices,),
dtype=torch.long if long_indices else torch.int32,
)
if randomize_indices:
indices = torch.randint(
0,
ind_range,
(num_indices,),
dtype=torch.long if long_indices else torch.int32,
device=device,
)
else:
indices = torch.zeros(
(num_indices),
dtype=torch.long if long_indices else torch.int32,
device=device,
)
global_idlist_lengths.append(lengths)
global_idlist_indices.append(indices)
global_idlist_kjt = KeyedJaggedTensor(
Expand All @@ -167,15 +178,15 @@ def _validate_pooling_factor(
for idx in range(len(idscore_ind_ranges)):
ind_range = idscore_ind_ranges[idx]
lengths_ = torch.abs(
torch.randn(batch_size * world_size)
torch.randn(batch_size * world_size, device=device)
+ (
idscore_pooling_factor[idx]
if idscore_pooling_factor
else pooling_avg
)
).int()
if variable_batch_size:
lengths = torch.zeros(batch_size * world_size).int()
lengths = torch.zeros(batch_size * world_size, device=device).int()
for r in range(world_size):
lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
lengths_[
Expand All @@ -185,13 +196,21 @@ def _validate_pooling_factor(
else:
lengths = lengths_
num_indices = cast(int, torch.sum(lengths).item())
indices = torch.randint(
0,
ind_range,
(num_indices,),
dtype=torch.long if long_indices else torch.int32,
)
weights = torch.rand((num_indices,))
if randomize_indices:
indices = torch.randint(
0,
ind_range,
(num_indices,),
dtype=torch.long if long_indices else torch.int32,
device=device,
)
else:
indices = torch.zeros(
(num_indices),
dtype=torch.long if long_indices else torch.int32,
device=device,
)
weights = torch.rand((num_indices,), device=device)
global_idscore_lengths.append(lengths)
global_idscore_indices.append(indices)
global_idscore_weights.append(weights)
Expand All @@ -206,8 +225,10 @@ def _validate_pooling_factor(
else None
)

global_float = torch.rand((batch_size * world_size, num_float_features))
global_label = torch.rand(batch_size * world_size)
global_float = torch.rand(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.rand(batch_size * world_size, device=device)

# Split global batch into local batches.
local_inputs = []
Expand Down
Loading