@@ -171,10 +171,27 @@ def select_scatter_decomposition(
171171 dim : int ,
172172 index : int ,
173173) -> torch .Tensor :
174- input_tensor .shape [dim ] = torch .le (index , input_tensor .shape [dim ])
174+ # input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
175+ # check if the dim is less than shape
176+ if input_tensor .shape [dim ] < index :
177+ raise AssertionError ("The index should not be greater than dim" )
178+
179+ # expanding the src_tensor to have the same dimension as input_tensor
175180 src_tensor = torch .expand (torch .unsqueeze (src_tensor , dim ), input_tensor .shape )
176- input_tensor_shape = input_tensor .shape
177- return torch .where (torch .eq ((input_tensor_shape [dim ]), index )), src_tensor , input_tensor )
181+ # check if the dimension of the src tensor is same as slice tensor
182+ select_tensor = torch .select (input_tensor , dim , index )
183+ if select_tensor .shape != src_tensor .shape :
184+ raise AssertionError (
185+ "The slice tensor shape should be equal to the src tensor shape"
186+ )
187+
188+ # make the index tensor
189+ # input_tensor_shape = input_tensor.shape
190+ # return torch.where(torch.eq((input_tensor_shape[dim]), index), src_tensor, input_tensor)
191+
192+ unbind_tensors = torch .unbind (input_tensor , dim )
193+ unbind_tensors [index ] = src_tensor
194+ return torch .cat (unbind_tensors , dim )
178195
179196
180197def get_decompositions (
0 commit comments