@@ -325,23 +325,17 @@ def find_source_slices(
325325 assert len (local_slice ) == len (tensor .shape )
326326 ndim = len (tensor .shape )
327327
328- def slice_intersect (a : slice , b : slice , dim_len : int ):
329- a_start , a_stop , a_step = a .indices (dim_len )
330- b_start , b_stop , b_step = b .indices (dim_len )
331- if a_step != 1 or b_step != 1 :
332- raise NotImplementedError ("Only support step size of 1" )
333- start = max (a_start , b_start )
334- stop = min (a_stop , b_stop )
328+ def slice_intersect (a : slice , b : slice ):
329+ start = max (a .start , b .start )
330+ stop = min (a .stop , b .stop )
335331 if start >= stop :
336332 return None
337333 return slice (start , stop , 1 )
338334
339335 for src_key , sl_src , sl_dst in tensor .slices :
340336 intersection = []
341337 for i in range (ndim ):
342- inter = slice_intersect (
343- local_slice [i ], sl_dst [i ], tensor .shape [i ]
344- )
338+ inter = slice_intersect (local_slice [i ], sl_dst [i ])
345339 if inter is None :
346340 break
347341 intersection .append (inter )
@@ -351,11 +345,11 @@ def slice_intersect(a: slice, b: slice, dim_len: int):
351345 for i in range (ndim ):
352346 dst = sl_dst [i ]
353347 src = sl_src [i ]
354- dim_len = tensor . shape [ i ]
355- dst_start , _ , _ = dst . indices ( dim_len )
356- src_start , _ , _ = src . indices ( dim_len )
357- inter_start , inter_stop , _ = intersection [i ].indices (
358- dim_len
348+ dst_start = dst . start
349+ src_start = src . start
350+ inter_start , inter_stop = (
351+ intersection [i ].start ,
352+ intersection [ i ]. stop ,
359353 )
360354 offset = inter_start - dst_start
361355 src_inter_start = src_start + offset
0 commit comments