Skip to content

Commit a285fc8

Browse files
add comment
1 parent 40526cb commit a285fc8

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchao/prototype/moe_training/tensor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
3838
B_is_3d = B.dim() == 3
3939
has_offs = kwargs.get(cls.offs_arg_name) is not None
4040
if A_is_2d and B_is_3d and has_offs:
41+
# prefer to use B to check use_triton, as that will be the weight/nn.Parameter
42+
# that is converted to ScaledGroupedMMTensor
4143
use_triton = (
42-
A._use_triton_for_per_group_scales
43-
if isinstance(A, cls)
44-
else B._use_triton_for_per_group_scales
44+
B._use_triton_for_per_group_scales
45+
if isinstance(B, cls)
46+
else A._use_triton_for_per_group_scales
4547
)
4648
return _scaled_grouped_mm(
4749
*args,

0 commit comments

Comments
 (0)