Skip to content

Commit

Permalink
Move permute_sparse_features tests and abstract impl to fbgemm (#2129)
Browse files Browse the repository at this point in the history
Summary:

Title

Reviewed By: zou3519

Differential Revision: D51211981
  • Loading branch information
williamwen42 authored and facebook-github-bot committed Nov 11, 2023
1 parent 2117dd3 commit b3fcb7a
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 2 deletions.
24 changes: 24 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,27 @@ def bounds_check_indices(
max_B: int = -1,
) -> None:
pass


@impl_abstract("fbgemm::permute_sparse_features")
def permute_sparse_features_abstract(
permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
torch._check(lengths.dtype == indices.dtype)
torch._check(permute.device == lengths.device)
torch._check(permute.device == indices.device)
if weights is not None:
torch._check(permute.device == weights.device)
num_output_features = permute.numel()
B = lengths.size(1)
permuted_lengths = lengths.new_empty(num_output_features, B)
output_size = torch.library.get_ctx().new_dynamic_size()
# pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
# expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
permuted_indices = indices.new_empty(output_size)
permuted_weights = None
if weights is not None:
# pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument,
# expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]`
permuted_weights = weights.new_empty(output_size)
return (permuted_lengths, permuted_indices, permuted_weights)
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2738,7 +2738,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"lengths_range_out(Tensor output, Tensor t_in, SymInt[]? shape=None) -> Tensor");
m.def(
"permute_sparse_features(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None) -> (Tensor, Tensor, Tensor?)");
"permute_sparse_features(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None) -> (Tensor, Tensor, Tensor?)",
{PT2_COMPLIANT_TAG});
m.def("Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
m.def("FloatToBfloat16Quantized(Tensor input) -> Tensor");
m.def(
Expand Down
163 changes: 162 additions & 1 deletion fbgemm_gpu/test/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import random
import unittest
from itertools import accumulate
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

import fbgemm_gpu

Expand Down Expand Up @@ -2402,6 +2402,167 @@ def validate(
"grad",
)

def permute_sparse_features_ref_(
self,
lengths: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
permute: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
T = lengths.size(0)
B = lengths.size(1)
permuted_lengths = torch.index_select(lengths.view(T, B), 0, permute)

original_segment_lengths = lengths.view(T, B).sum(dim=1, dtype=torch.int32)
original_segment_start = torch.ops.fbgemm.asynchronous_exclusive_cumsum(
original_segment_lengths.view(-1)
)

permuted_indices = []
permuted_weights = []
for i in range(permute.size(0)):
start = original_segment_start[permute[i]]
end = start + original_segment_lengths[permute[i]]
permuted_indices.append(indices[start:end])
if weights is not None:
permuted_weights.append(weights[start:end])

permuted_indices = torch.cat(permuted_indices, dim=0).flatten()

if weights is None:
permuted_weights = None
else:
permuted_weights = torch.cat(permuted_weights, dim=0).flatten()

return permuted_lengths, permuted_indices, permuted_weights

@given(
B=st.integers(min_value=1, max_value=20),
T=st.integers(min_value=1, max_value=20),
L=st.integers(min_value=2, max_value=20),
long_index=st.booleans(),
has_weight=st.booleans(),
)
@settings(max_examples=20, deadline=None)
def test_permute_sparse_features(
self, B: int, T: int, L: int, long_index: bool, has_weight: bool
) -> None:
index_dtype = torch.int64 if long_index else torch.int32
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
weights = torch.rand(int(lengths.sum().item())).float() if has_weight else None
indices = torch.randint(
low=1,
high=int(1e5),
size=cast(Tuple[int, ...], (lengths.sum().item(),)),
).type(index_dtype)
permute_list = list(range(T))
random.shuffle(permute_list)
permute = torch.IntTensor(permute_list)

(
permuted_lengths_cpu,
permuted_indices_cpu,
permuted_weights_cpu,
) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, weights)
(
permuted_lengths_ref,
permuted_indices_ref,
permuted_weights_ref,
# pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`.
) = self.permute_indices_ref_(lengths, indices, weights, permute.long())
torch.testing.assert_close(permuted_indices_cpu, permuted_indices_ref)
torch.testing.assert_close(permuted_lengths_cpu, permuted_lengths_ref)
if has_weight:
torch.testing.assert_close(permuted_weights_cpu, permuted_weights_ref)
else:
assert permuted_weights_cpu is None and permuted_weights_ref is None

if gpu_available:
(
permuted_lengths_gpu,
permuted_indices_gpu,
permuted_weights_gpu,
) = torch.ops.fbgemm.permute_sparse_features(
permute.cuda(),
lengths.cuda(),
indices.cuda(),
weights.cuda() if has_weight and weights is not None else None,
)
torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
if has_weight:
torch.testing.assert_close(
permuted_weights_gpu.cpu(), permuted_weights_cpu
)
else:
assert permuted_weights_gpu is None

@given(
B=st.integers(min_value=1, max_value=20),
T=st.integers(min_value=1, max_value=20),
L=st.integers(min_value=2, max_value=20),
long_index=st.booleans(),
has_weight=st.booleans(),
)
@settings(max_examples=20, deadline=None)
def test_permute_sparse_features_with_repeats(
self, B: int, T: int, L: int, long_index: bool, has_weight: bool
) -> None:
index_dtype = torch.int64 if long_index else torch.int32
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
weights = torch.rand(int(lengths.sum().item())).float() if has_weight else None
indices = torch.randint(
low=1,
high=int(1e5),
size=cast(Tuple[int, ...], (lengths.sum().item(),)),
).type(index_dtype)
permute_list = list(range(T))

num_repeats = random.randint(0, T)
for _ in range(num_repeats):
permute_list.append(random.randint(0, T - 1))

random.shuffle(permute_list)
permute = torch.IntTensor(permute_list)

(
permuted_lengths_cpu,
permuted_indices_cpu,
permuted_weights_cpu,
) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, weights)
(
permuted_lengths_ref,
permuted_indices_ref,
permuted_weights_ref,
# pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`.
) = self.permute_indices_ref_(lengths, indices, weights, permute.long())
torch.testing.assert_close(permuted_indices_cpu, permuted_indices_ref)
torch.testing.assert_close(permuted_lengths_cpu, permuted_lengths_ref)
if has_weight:
torch.testing.assert_close(permuted_weights_cpu, permuted_weights_ref)
else:
assert permuted_weights_cpu is None and permuted_weights_ref is None

if gpu_available:
(
permuted_lengths_gpu,
permuted_indices_gpu,
permuted_weights_gpu,
) = torch.ops.fbgemm.permute_sparse_features(
permute.cuda(),
lengths.cuda(),
indices.cuda(),
weights.cuda() if has_weight and weights is not None else None,
)
torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
if has_weight:
torch.testing.assert_close(
permuted_weights_gpu.cpu(), permuted_weights_cpu
)
else:
assert permuted_weights_cpu is None


failures_dict_path: str = get_file_path_2(
"", os.path.dirname(__file__), "failures_dict.json"
Expand Down

0 comments on commit b3fcb7a

Please sign in to comment.