1 change: 0 additions & 1 deletion backends/cadence/aot/ops_registrations.py
@@ -60,7 +60,6 @@ def _validate_ref_impl_exists() -> None:
"cadence::quantized_softmax.per_tensor",
"cadence::quantized_conv2d_nchw", # We should only support per_tensor variant, should remove
"cadence::quantized_relu", # We should only support per_tensor variant, should remove
"cadence::linalg_svd",
"cadence::quantized_conv2d_nhwc", # We should only support per_tensor variant, should remove
"cadence::quantized_softmax",
"cadence::quantized_w8a32_gru",
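For context, a rough sketch of the check this deletion interacts with, using only names visible in this diff (the helper below and its arguments are illustrative, not the repo's actual code): impl_tracked records each op name in _REGISTERED_REF_IMPLEMENTATIONS, and _validate_ref_impl_exists only exempts ops still listed in the skip list above, so dropping "cadence::linalg_svd" here relies on the reference implementation added in ref_implementations.py below.

# Illustrative sketch only; check_ref_impls, all_ops, registered, and skip_list are hypothetical names.
def check_ref_impls(all_ops: set[str], registered: set[str], skip_list: set[str]) -> None:
    missing = {op for op in all_ops if op not in registered and op not in skip_list}
    if missing:
        raise RuntimeError(f"Missing reference implementations for: {sorted(missing)}")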
71 changes: 56 additions & 15 deletions backends/cadence/aot/ref_implementations.py
@@ -21,11 +21,13 @@
# Registry to track all ops with reference implementations
_REGISTERED_REF_IMPLEMENTATIONS: set[str] = set()

_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...]


# Custom impl wrapper that tracks registrations
def impl_tracked(
lib: Library, op_name: str
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
"""Wrapper around impl that tracks registered ops."""
_REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
return impl(lib, op_name)
@@ -312,7 +314,7 @@ def quantized_add_per_tensor(
dequant_Y = Y_scale * (Y - Y_zero_point)

# q_min/q_max are unused args
return quantize_per_tensor(
out = quantize_per_tensor(
dequant_X + dequant_Y,
out_scale,
out_zero_point,
@@ -321,6 +323,9 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
dtype,
)

assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
def quantized_add_asym8sxasym8s_asym8s_per_tensor(
@@ -338,9 +343,11 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
if Y.dtype != torch.int8:
raise ValueError("Y dtype must be torch.int8")

return quantized_add_per_tensor(
out = quantized_add_per_tensor(
X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
)
assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
@@ -359,9 +366,11 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
if Y.dtype != torch.uint8:
raise ValueError("Y dtype must be torch.int8")

return quantized_add_per_tensor(
out = quantized_add_per_tensor(
X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
)
assert isinstance(out, torch.Tensor)
return out


def quantized_linear_common(
@@ -407,14 +416,16 @@ def quantized_linear_common(
(weight - weight_zero_point).float(),
bias.float(),
)
return quantize_per_tensor(
out = quantize_per_tensor(
out,
out_scale,
out_zero_point,
torch.iinfo(dtype).min,
torch.iinfo(dtype).max,
dtype,
).reshape(*leading_dims, N)
)
assert isinstance(out, torch.Tensor)
return out.reshape(*leading_dims, N)


def quantized_linear_variant(
@@ -576,14 +587,16 @@ def quantized_matmul(
(X - X_zero_point).float(),
(Y - Y_zero_point).float(),
)
return quantize_per_tensor(
out = quantize_per_tensor(
out,
out_scale,
out_zero_point,
torch.iinfo(X.dtype).min,
torch.iinfo(X.dtype).max,
X.dtype,
)
assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s")
@@ -603,7 +616,7 @@ def quantized_matmul_asym8sxasym8s_asym8s(
if Y.dtype != torch.int8:
raise ValueError("Y dtype must be torch.int8")

return quantized_matmul(
out = quantized_matmul(
X,
X_zero_point,
Y,
@@ -614,6 +627,8 @@ def quantized_matmul_asym8sxasym8s_asym8s(
out_zero_point,
transposed,
)
assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u")
@@ -633,7 +648,7 @@ def quantized_matmul_asym8uxasym8u_asym8u(
if Y.dtype != torch.uint8:
raise ValueError("Y dtype must be torch.uint8")

return quantized_matmul(
out = quantized_matmul(
X,
X_zero_point,
Y,
@@ -644,6 +659,8 @@ def quantized_matmul_asym8uxasym8u_asym8u(
out_zero_point,
transposed,
)
assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "quantized_layer_norm.per_tensor")
@@ -681,18 +698,21 @@ def quantized_layer_norm_per_tensor(
float_input_tensor = dequantize_per_tensor(
input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
)
assert isinstance(float_input_tensor, torch.Tensor)
out = torch.nn.functional.layer_norm(
float_input_tensor, normalized_shape, weight, bias, eps=eps
)

return quantize_per_tensor(
out = quantize_per_tensor(
out,
output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)
assert isinstance(out, torch.Tensor)
return out


def quantized_conv_per_tensor(
@@ -754,14 +774,16 @@ def quantized_conv_per_tensor(
else:
raise ValueError("Input tensor must be 3D or 4D")

return quantize_per_tensor(
out = quantize_per_tensor(
float_out,
output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)
assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "quantized_conv2d_nchw.per_tensor")
@@ -983,7 +1005,7 @@ def variant(
# Call the appropriate base function
match layout:
case "nchw":
return quantized_conv2d_nchw_per_tensor(
out = quantized_conv2d_nchw_per_tensor(
input_tensor,
weight,
bias,
@@ -1000,7 +1022,7 @@ def variant(
out_shift,
)
case "nhwc":
return quantized_conv2d_nhwc_per_tensor(
out = quantized_conv2d_nhwc_per_tensor(
input_tensor,
weight,
bias,
@@ -1019,6 +1041,9 @@ def variant(
case _:
raise ValueError(f"Unknown layout {layout}")

assert isinstance(out, torch.Tensor)
return out

return variant

return decorator
@@ -1293,14 +1318,16 @@ def quantized_relu_common(
dequantized_X = torch.where(
X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
).to(torch.float32)
return quantize_per_tensor(
out = quantize_per_tensor(
dequantized_X,
out_scale,
out_zero_point,
torch.iinfo(X.dtype).min,
torch.iinfo(X.dtype).max,
X.dtype,
)
assert isinstance(out, torch.Tensor)
return out


def quantized_relu_variant(
@@ -1557,7 +1584,7 @@ def im2row_per_tensor(
in_zero_point: int,
channel_last: bool = False,
) -> torch.Tensor:
return im2row(
out = im2row(
input_tensor,
kernel_size,
dilation,
@@ -1566,6 +1593,8 @@ def im2row_per_tensor(
torch.tensor(in_zero_point, dtype=torch.int32),
channel_last,
)
assert isinstance(out, torch.Tensor)
return out


@impl_tracked(m, "transposed_im2row")
@@ -1773,3 +1802,15 @@ def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.T
@impl_tracked(m, "idma_wait")
def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
return src.clone()


@impl_tracked(m, "linalg_svd")
def linalg_svd(
A: torch.Tensor,
full_matrices: bool = False,
compute_uv: bool = True,
driver: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert compute_uv
U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
return U.contiguous(), S.contiguous(), Vh.contiguous()
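For context, a minimal sketch of what the new cadence::linalg_svd reference implementation computes and the contiguity guarantee the new test checks, written against plain torch.linalg.svd (the Cadence op itself, invoked as torch.ops.cadence.linalg_svd in the test below, assumes these ops have been registered); the tensor values here are illustrative only.

import torch

# Same computation as the reference implementation above: SVD with contiguous outputs.
A = torch.eye(4, 4, dtype=torch.float32)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
U, S, Vh = U.contiguous(), S.contiguous(), Vh.contiguous()

# Reconstruction and contiguity checks, mirroring the new unit test.
assert torch.allclose(U @ torch.diag(S) @ Vh, A)
assert U.is_contiguous() and S.is_contiguous() and Vh.is_contiguous()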
31 changes: 31 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -2632,3 +2632,34 @@ def test_quantized_embedding_byte(
expected_out,
)
)

@expand(
[
*[
(
dtype,
(4, 4),
full_matrices,
)
for dtype in [torch.float32, torch.float64]
for full_matrices in [True, False]
]
]
)
def test_linalg_svd_outputs_are_contiguous(
self,
dtype: torch.dtype,
shape: tuple[int, int],
full_matrices: bool,
) -> None:
m, n = shape
a = torch.eye(m, n, dtype=dtype)

U, S, Vh = torch.ops.cadence.linalg_svd(a, full_matrices)

self.assertTrue(U.is_contiguous(), "U not contiguous")
self.assertTrue(S.is_contiguous(), "S not contiguous")
self.assertTrue(Vh.is_contiguous(), "Vh not contiguous")
self.assertTrue(U.dtype == dtype, "U dtype mismatch")
self.assertTrue(S.dtype == dtype, "S dtype mismatch")
self.assertTrue(Vh.dtype == dtype, "Vh dtype mismatch")