
Commit fb39c1a

Moe support (#3811)
Co-authored-by: lanluo-nvidia <lanl@nvidia.com>
1 parent 1e4730c commit fb39c1a

4 files changed: 264 additions, 97 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 1 deletion
@@ -888,6 +888,7 @@ def aten_ops_select(
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3168,7 +3169,9 @@ def aten_ops_upsample_bicubic2d(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.topk.default, capability_validator=topk_validator
+    torch.ops.aten.topk.default,
+    capability_validator=topk_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
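Both converters now declare supports_dynamic_shapes=True, so graphs containing aten.index_put or aten.topk can be compiled with dynamic input dimensions instead of falling back to PyTorch. A minimal usage sketch with a dynamic batch dimension (the module and shape ranges below are illustrative, not part of this commit):

import torch
import torch_tensorrt

class TopKModule(torch.nn.Module):
    def forward(self, x):
        # aten.topk is one of the ops whose converter now accepts dynamic shapes
        return torch.topk(x, k=3, dim=-1)

model = TopKModule().eval().cuda()
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 16),
        opt_shape=(8, 16),
        max_shape=(32, 16),  # first (batch) dimension is dynamic
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
values, indices = trt_model(torch.randn(4, 16, device="cuda"))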

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 139 additions & 91 deletions
@@ -257,6 +257,7 @@ def index(
             )
         else:
             dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
+
         mult_d1 = convert_binary_elementwise(
             ctx,
             target,
@@ -548,6 +549,9 @@ def index_put_converter(
     accumulate: bool = False,
 ) -> TRTTensor:
     # Convert 'input_indices' to TRT tensors (or keep None as is)
+    input_indices = expand_boolean_indices(
+        ctx, target, source_ir, name, input_tensor, input_indices
+    )
     indices: List[Optional[Union[TRTTensor, None]]] = []
     for i, idx in enumerate(input_indices):
         if idx is None:
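The new expand_boolean_indices call normalizes boolean mask indices into integer index tensors before the rest of the converter runs. A small eager-mode sketch of the idea the helper implements (the helper itself is defined elsewhere in this file; the example is illustrative only):

import torch

x = torch.zeros(4, 3)
mask = torch.tensor([True, False, True, False])    # boolean index for dim 0
values = torch.ones(2, 3)

a = x.clone()
a[mask] = values                                   # boolean-mask index_put
b = x.clone()
b[mask.nonzero(as_tuple=True)[0]] = values         # equivalent integer-index form
assert torch.equal(a, b)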
@@ -571,22 +575,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1
 
     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1
 
     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
@@ -608,46 +630,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)
 
-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
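This block enumerates coordinates over the free (non-indexed) dimensions; when there is only one free dimension, the arange already is the coordinate list, so the stacking branch is skipped. A rough eager-mode picture of what the grid holds (illustrative only, not converter code):

import torch

# F_shapes are the sizes of the free (non-indexed) dimensions.
F_shapes = [2, 3]
grid = torch.cartesian_prod(*[torch.arange(s) for s in F_shapes])
print(grid)   # shape (6, 2): one row of free-dim coordinates per element
# With a single free dimension, the product of one arange is just that arange,
# which is why the rewritten converter skips the reshape/stack path there.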
@@ -672,21 +698,15 @@ def index_put_converter(
 
     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []
 
     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -695,24 +715,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i_idx]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -731,20 +739,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )
 
     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)
 
     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )
 
     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
@@ -761,7 +773,12 @@ def index_put_converter(
         )
     else:  # Non-scalar case
         values_shape = list(values.shape)
-        if K > 0 and N in values_shape:
+        if (
+            K > 0
+            and N in values_shape
+            and (len(F) > 1 and max(F) - min(F) + 1 == len(F))
+        ):
+            # Continuous case
             n_idx = values_shape.index(N)
             permute_order = [n_idx] + [
                 i for i in range(len(values_shape)) if i != n_idx
@@ -807,31 +824,27 @@ def index_put_converter(
                 tuple(broadcast_shape),
             )
         else:
+            # Discontinuous case
             values_shape_padded = [1] * (
                 len(expected_shape) - len(values.shape)
             ) + list(values.shape)
            broadcast_shape = []
             for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
-                if val_dim == 1 or exp_dim == val_dim:
+                if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
+                    broadcast_shape.append(-1)
+                elif val_dim == 1 or exp_dim == val_dim:
                     broadcast_shape.append(exp_dim)
                 else:
                     raise ValueError(
                         f"Cannot broadcast {values.shape} to {expected_shape}"
                     )
-            values_reshaped = impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_values",
-                values,
-                tuple(broadcast_shape),
-            )
+
             values_expanded = impl.slice.expand(
                 ctx,
                 target,
                 source_ir,
                 f"{name}_expand_values",
-                values_reshaped,
+                values,
                 expected_shape,
             )
 
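Before the scatter, values is broadcast to expected_shape = (N,) + tuple(F_shapes), with -1 standing in for a dynamic N. The eager semantics being reproduced look like this (illustrative only):

import torch

x = torch.zeros(4, 3)
idx = (torch.tensor([1, 3]),)   # one indexed dim (K = 1), two index values (N = 2)
vals = torch.tensor(5.0)        # scalar, broadcast to (2, 3) == (N,) + F_shapes
x.index_put_(idx, vals)
print(x)                        # rows 1 and 3 are now all 5.0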
@@ -842,16 +855,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
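With accumulate=True the converter no longer relies on ScatterMode.ND_ELEMENTWISE_ADD; it scatters the updates into a zero tensor and adds the original input afterwards. The two formulations agree whenever each index row appears at most once, as this eager sketch shows (illustrative only):

import torch

x = torch.arange(6.0).reshape(2, 3)
idx = (torch.tensor([0, 1]), torch.tensor([2, 0]))
vals = torch.tensor([10.0, 20.0])

# Reference: accumulate=True via index_put_
ref = x.clone()
ref.index_put_(idx, vals, accumulate=True)

# Decomposition used above: scatter values into zeros, then add the input
zeros = torch.zeros_like(x)
zeros.index_put_(idx, vals, accumulate=False)
out = zeros + x
assert torch.equal(ref, out)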

py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ def remove_num_users_is_0_nodes(
             and len(node.all_input_nodes) > 0
         ):
             gm.graph.erase_node(node)
-            gm = clean_up_graph_after_modifications(gm)
+
+    gm = clean_up_graph_after_modifications(gm)
 
     logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}")
 
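The pass now erases every zero-user node first and calls clean_up_graph_after_modifications once, after the loop, instead of re-cleaning the graph for each erased node. A self-contained torch.fx sketch of the same pattern (not the torch_tensorrt pass itself):

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        unused = x + 1          # dead: never used by the output
        return x * 2

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op not in ("placeholder", "output") and len(node.users) == 0:
        gm.graph.erase_node(node)
# Single cleanup after all erasures, mirroring the hoisted call above.
gm.graph.lint()
gm.recompile()
print(gm.code)                  # the dead add is gone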
