Skip to content

Commit 3dabad7

Browse files
committed
fix bug
1 parent a7b5037 commit 3dabad7

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -325,23 +325,17 @@ def find_source_slices(
325325
assert len(local_slice) == len(tensor.shape)
326326
ndim = len(tensor.shape)
327327

328-
def slice_intersect(a: slice, b: slice, dim_len: int):
329-
a_start, a_stop, a_step = a.indices(dim_len)
330-
b_start, b_stop, b_step = b.indices(dim_len)
331-
if a_step != 1 or b_step != 1:
332-
raise NotImplementedError("Only support step size of 1")
333-
start = max(a_start, b_start)
334-
stop = min(a_stop, b_stop)
328+
def slice_intersect(a: slice, b: slice):
329+
start = max(a.start, b.start)
330+
stop = min(a.stop, b.stop)
335331
if start >= stop:
336332
return None
337333
return slice(start, stop, 1)
338334

339335
for src_key, sl_src, sl_dst in tensor.slices:
340336
intersection = []
341337
for i in range(ndim):
342-
inter = slice_intersect(
343-
local_slice[i], sl_dst[i], tensor.shape[i]
344-
)
338+
inter = slice_intersect(local_slice[i], sl_dst[i])
345339
if inter is None:
346340
break
347341
intersection.append(inter)
@@ -351,11 +345,11 @@ def slice_intersect(a: slice, b: slice, dim_len: int):
351345
for i in range(ndim):
352346
dst = sl_dst[i]
353347
src = sl_src[i]
354-
dim_len = tensor.shape[i]
355-
dst_start, _, _ = dst.indices(dim_len)
356-
src_start, _, _ = src.indices(dim_len)
357-
inter_start, inter_stop, _ = intersection[i].indices(
358-
dim_len
348+
dst_start = dst.start
349+
src_start = src.start
350+
inter_start, inter_stop = (
351+
intersection[i].start,
352+
intersection[i].stop,
359353
)
360354
offset = inter_start - dst_start
361355
src_inter_start = src_start + offset

python/paddle/distributed/flex_checkpoint/aoa/lexer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import re
1616
from enum import Enum, auto
1717

18-
from .macros import macro_registry
1918

2019
class Token:
2120
def __init__(self, type, value):
@@ -57,6 +56,8 @@ class Lexer:
5756
]
5857

5958
def __init__(self, context):
59+
from .macros import macro_registry
60+
6061
self.macros = [list(d.values())[1] for d in macro_registry.macros]
6162
self.get_token = re.compile(
6263
'|'.join(

0 commit comments

Comments
 (0)