Skip to content

Commit

Permalink
[TEST] Fix test cache (triton-lang#1588)
Browse files Browse the repository at this point in the history
To avoid puzzling segmentation faults caused by multiprocessing, this
PR:

- Uses "spawn" instead of "fork".
- Defines the `instance_descriptor` namedtuple at module level.
- Defines the `kernel_sub` JITFunction only inside the child process.
  • Loading branch information
Jokeren authored Apr 28, 2023
1 parent b4437fe commit fc75207
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions python/test/unit/runtime/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,31 +170,34 @@ def kernel_add(a, b, o, N: tl.constexpr):
assert bins[0].asm['ttir'] != bins[1].asm['ttir']


# Declared at module level (not inside the test function) so that the class can
# be found by name when a config instance is pickled and handed to a child
# process started with the "spawn" method.
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])


def compile_fn(config, cc):
    """Compile a small subtraction kernel; used as the child-process target.

    ``kernel_sub`` is defined inside this function so that the JIT function
    object is created by whichever process calls ``compile_fn`` instead of
    being handed over from the parent process.

    Args:
        config: Forwarded as the single element of ``configs``.
        cc: Forwarded unchanged as ``cc``; the caller in this file passes
            ``major * 10 + minor`` of the CUDA device capability.
    """
    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    triton.compile(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        constants={3: 32},
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )


def test_compile_in_subproc() -> None:
    """Check that ``compile_fn`` exits cleanly in a spawned child process."""
    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    config = instance_descriptor(tuple(range(4)), ())

    # Use a private "spawn" context rather than
    # multiprocessing.set_start_method(): the global setter raises
    # RuntimeError once a start method has already been fixed, and it would
    # change the default start method for every other test in the session.
    mp_ctx = multiprocessing.get_context("spawn")
    proc = mp_ctx.Process(target=compile_fn, args=(config, cc))
    proc.start()
    proc.join()
    # A non-zero exit code (including a negative one for death by signal)
    # means the compilation in the child process failed.
    assert proc.exitcode == 0
Expand Down

0 comments on commit fc75207

Please sign in to comment.