From fc75207290ad7697d2c42304f9389921737e088c Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 28 Apr 2023 10:39:06 -0400 Subject: [PATCH] [TEST] Fix test cache (#1588) To avoid puzzling segmentation faults caused by multiprocessing, this PR: - Uses "spawn" instead of "fork". - Defines the `instance_descriptor` namedtuple globally. - Makes the `kernel_sub` JITFunction defined by the child process only. --- python/test/unit/runtime/test_cache.py | 37 ++++++++++++++------------ 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 38d396cf0a1f..02b9185eab39 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -170,31 +170,34 @@ def kernel_add(a, b, o, N: tl.constexpr): assert bins[0].asm['ttir'] != bins[1].asm['ttir'] -def test_compile_in_subproc() -> None: +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"]) + + +def compile_fn(config, cc): @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) - tl.store(o + idx, - tl.load(a + idx) - tl.load(b + idx) * 777) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + triton.compile( + fn=kernel_sub, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + device=0, + constants={3: 32}, + configs=[config], + warm_cache_only=True, + cc=cc, + ) + +def test_compile_in_subproc() -> None: major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor - config = namedtuple("instance_descriptor", [ - "divisible_by_16", "equal_to_1"])( - tuple(range(4)), - ()) + config = instance_descriptor(tuple(range(4)), ()) + multiprocessing.set_start_method('spawn') proc = multiprocessing.Process( - target=triton.compile, - kwargs=dict( - fn=kernel_sub, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - device=0, - constants={3: 32}, - configs=[config], - warm_cache_only=True, - cc=cc, - )) + target=compile_fn, + 
args=(config, cc)) proc.start() proc.join() assert proc.exitcode == 0