@@ -991,7 +991,7 @@ def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
991
991
supports_multiple_dims: bool = kwargs.get('supports_multiple_dims', True)
992
992
993
993
# TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
994
- # use op_info.genearte_args_kwargs directly.
994
+ # use op_info.generate_args_kwargs directly.
995
995
generate_args_kwargs = kwargs.get('generate_args_kwargs', lambda *args, **kwargs: (yield tuple(), {}))
996
996
997
997
inputs: List[SampleInput] = []
@@ -1170,7 +1170,7 @@ class ReductionOpInfo(OpInfo):
1170
1170
the optional keyword parameters of the ReductionOpInfo constructor.
1171
1171
1172
1172
If a reduction operator does not yet implement the full required API of
1173
- reduction operators, this should be documented by skipping the failing
1173
+ reduction operators, this should be documented by xfailing the failing
1174
1174
tests rather than adding optional parameters to ReductionOpInfo.
1175
1175
1176
1176
NOTE
@@ -2120,6 +2120,16 @@ def sample_inputs_isclose(
2120
2120
yield SampleInput(lhs, args=(rhs,),
2121
2121
kwargs=dict(op_kwargs, rtol=rtol, atol=atol, equal_nan=equal_nan))
2122
2122
2123
def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
    """Generate SampleInputs for linalg.vecdot.

    Produces pairs of same-shaped tensors covering empty/singleton/small
    batch shapes crossed with empty/singleton/small vector lengths, once
    with the default reduction dim and once per explicit dim value.
    """
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    batch_shapes = ((), (0,), (1,), (5,))
    vector_lens = (0, 1, 3, 5)
    for batch, length in product(batch_shapes, vector_lens):
        full_shape = batch + (length,)
        # Default dim (reduces over the last dimension).
        yield SampleInput(make_arg(full_shape), args=(make_arg(full_shape),))
        # Every explicit dim, batch dimensions included.
        for axis in range(len(full_shape)):
            yield SampleInput(make_arg(full_shape), args=(make_arg(full_shape),), kwargs=dict(dim=axis))
2123
2133
def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
2124
2134
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2125
2135
return (SampleInput(make_arg((1, 2))),
@@ -9898,6 +9908,13 @@ def ref_pairwise_distance(input1, input2):
9898
9908
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
9899
9909
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
9900
9910
),
9911
+ OpInfo('linalg.vecdot',
9912
+ aten_name='linalg_vecdot',
9913
+ ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
9914
+ dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
9915
+ sample_inputs_func=sample_inputs_linalg_vecdot,
9916
+ supports_forward_ad=True,
9917
+ supports_fwgrad_bwgrad=True),
9901
9918
OpInfo('linalg.cond',
9902
9919
aten_name='linalg_cond',
9903
9920
dtypes=floating_and_complex_types(),
0 commit comments