@@ -257,6 +257,7 @@ def index(
                 )
             else:
                 dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
+
             mult_d1 = convert_binary_elementwise(
                 ctx,
                 target,
@@ -548,6 +549,9 @@ def index_put_converter(
     accumulate: bool = False,
 ) -> TRTTensor:
     # Convert 'input_indices' to TRT tensors (or keep None as is)
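+    # expand_boolean_indices is expected to materialize boolean mask indices as
+    # integer index tensors, so everything below sees integer (or None) indices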
+    input_indices = expand_boolean_indices(
+        ctx, target, source_ir, name, input_tensor, input_indices
+    )
     indices: List[Optional[Union[TRTTensor, None]]] = []
     for i, idx in enumerate(input_indices):
         if idx is None:
@@ -571,22 +575,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = []
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "_idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
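+        # With dynamic input shapes, some entries of index_shapes are runtime
+        # shape tensors, so N may end up being a TRTTensor rather than a Python
+        # int; later code branches on isinstance(N, TRTTensor) for this case.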
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1

     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1

     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
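+        # unsqueeze adds the trailing axis without baking idx.shape[0] into a
+        # static reshape, so index tensors with a dynamic first dim still work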
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
@@ -608,46 +630,50 @@ def index_put_converter(
             )
             arange_tensors.append(arange_tensor)

-        meshgrid_tensors = []
-        for i, arange in enumerate(arange_tensors):
-            reshape_shape = [1] * len(F)
-            reshape_shape[i] = F_shapes[i]
-            arange_reshaped = impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_arange_F_{F[i]}",
-                arange,
-                tuple(reshape_shape),
-            )
-            expanded_arange = impl.slice.expand(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_expand_arange_F_{F[i]}",
-                arange_reshaped,
-                tuple(F_shapes),
-            )
-            meshgrid_tensors.append(expanded_arange)
-
-        meshgrid_stacked = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_stack_meshgrid",
-            [
-                impl.shuffle.reshape(
+        if len(arange_tensors) == 1:
+            # No need to stack
+            meshgrid_stacked = arange_tensors[0]
+        else:
+            meshgrid_tensors = []
+            for i, arange in enumerate(arange_tensors):
+                reshape_shape = [1] * len(F)
+                reshape_shape[i] = F_shapes[i]
+                arange_reshaped = impl.shuffle.reshape(
                     ctx,
                     target,
                     source_ir,
-                    f"{name}_reshape_mesh_{i}",
-                    t,
-                    (*F_shapes, 1),
+                    f"{name}_reshape_arange_F_{F[i]}",
+                    arange,
+                    tuple(reshape_shape),
                 )
-                for i, t in enumerate(meshgrid_tensors)
-            ],
-            dim=-1,
-        )
+                expanded_arange = impl.slice.expand(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_expand_arange_F_{F[i]}",
+                    arange_reshaped,
+                    tuple(F_shapes),
+                )
+                meshgrid_tensors.append(expanded_arange)
+
+            meshgrid_stacked = impl.cat.cat(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_stack_meshgrid",
+                [
+                    impl.shuffle.reshape(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_reshape_mesh_{i}",
+                        t,
+                        (*F_shapes, 1),
+                    )
+                    for i, t in enumerate(meshgrid_tensors)
+                ],
+                dim=-1,
+            )
         meshgrid_reshaped = impl.shuffle.reshape(
             ctx,
             target,
@@ -672,21 +698,15 @@ def index_put_converter(

     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
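+        # Keep the per-dimension index tensors in a list; each entry is picked
+        # directly below, avoiding the previous cat-then-slice round trip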
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []

     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -695,24 +715,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i_idx]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
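+            # -1 lets the slice shape follow the runtime value of N when N is
+            # a shape tensor rather than a static Python int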
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -731,20 +739,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )

     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)

     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
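+    # A leading -1 stands in for N when N is only known at runtime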
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )

     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
@@ -761,7 +773,12 @@ def index_put_converter(
         )
     else:  # Non-scalar case
         values_shape = list(values.shape)
-        if K > 0 and N in values_shape:
+        if (
+            K > 0
+            and N in values_shape
+            and (len(F) > 1 and max(F) - min(F) + 1 == len(F))
+        ):
+            # Contiguous free-dimension case
            n_idx = values_shape.index(N)
            permute_order = [n_idx] + [
                i for i in range(len(values_shape)) if i != n_idx
@@ -807,31 +824,27 @@ def index_put_converter(
                 tuple(broadcast_shape),
             )
         else:
+            # Non-contiguous free-dimension case
             values_shape_padded = [1] * (
                 len(expected_shape) - len(values.shape)
             ) + list(values.shape)
             broadcast_shape = []
             for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
-                if val_dim == 1 or exp_dim == val_dim:
+                if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
+                    broadcast_shape.append(-1)
+                elif val_dim == 1 or exp_dim == val_dim:
                     broadcast_shape.append(exp_dim)
                 else:
                     raise ValueError(
                         f"Cannot broadcast {values.shape} to {expected_shape}"
                     )
-            values_reshaped = impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_values",
-                values,
-                tuple(broadcast_shape),
-            )
+
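+            # values is expanded straight to expected_shape; broadcast_shape is
+            # now only computed to validate broadcastability above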
             values_expanded = impl.slice.expand(
                 ctx,
                 target,
                 source_ir,
                 f"{name}_expand_values",
-                values_reshaped,
+                values,
                 expected_shape,
             )

@@ -842,16 +855,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_indices_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
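+        # Emulate accumulation: scatter the values into a zero tensor of the
+        # input's (possibly dynamic) shape, then add the original input back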
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter_add", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add_input", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
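
For reference, a minimal PyTorch sketch of the "scatter into zeros, then add" accumulate emulation used above. The snippet is illustrative and not part of the commit; the tensors and coordinates are invented, and it assumes unique index coordinates (a plain scatter keeps only the last write per duplicate coordinate):

import torch

x = torch.arange(9.0).reshape(3, 3)
idx = (torch.tensor([0, 2]), torch.tensor([1, 1]))
vals = torch.tensor([10.0, 20.0])

# Reference semantics: aten.index_put_ with accumulate=True
expected = x.clone()
expected.index_put_(idx, vals, accumulate=True)

# Emulation: non-accumulating scatter into zeros shaped like x, then add the
# original input back; unindexed positions receive +0 and stay untouched.
zeros = torch.zeros_like(x)
zeros[idx] = vals
emulated = zeros + x

assert torch.equal(expected, emulated)  # holds when index coordinates are unique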