
device_prop.hpp: move static map to helper function and initialize there #1763

Open
wants to merge 1 commit into base: develop
Conversation

coconutruben

Summary:

Why

  • The map in its current location causes hard-to-debug segfaults, for reasons not yet understood, when running under Inductor without ASan. With the helper, the map is still a static const, so it is still initialized only once, on the first call.

What

  • Move the static name-lookup map out of the inlined get_device_name function and into a helper function at file scope, initializing it there.
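The change described above can be sketched as follows. This is a minimal illustration of the pattern, not the actual contents of device_prop.hpp: the identifiers and the arch-to-name entries are placeholders.

```
#include <string>
#include <unordered_map>

// Helper that owns the lookup table. The map is still a static const, so
// it is initialized exactly once, on the first call; since C++11 that
// initialization is also guaranteed to be thread-safe.
inline const std::unordered_map<std::string, std::string>& device_name_map()
{
    static const std::unordered_map<std::string, std::string> names = {
        {"gfx90a", "MI200"}, // placeholder entries for illustration
        {"gfx942", "MI300"},
    };
    return names;
}

// The inlined getter now only consults the helper instead of defining
// the static map in its own body.
inline std::string get_device_name(const std::string& arch)
{
    const auto& names = device_name_map();
    const auto it     = names.find(arch);
    // Fall back to the raw arch string when no friendly name is known.
    return it == names.end() ? arch : it->second;
}
```

Keeping the static inside a single helper gives the map one well-defined initialization point, instead of a function-local static embedded in an inline function that every including translation unit instantiates.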

Test Plan:

Something like this, run without ASan:

```
import torch

import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.utils import fresh_inductor_cache

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.mm(x, y)

M, N, K = 128, 128, 128
dtype = torch.float16
A = torch.randn(M, K, dtype=dtype).cuda()
B = torch.randn(K, N, dtype=dtype).cuda()

# create a fresh inductor cache
with fresh_inductor_cache():
    # sample the different backends independently
    with inductor_config.patch(
        {"max_autotune_gemm_backends": "ATEN,CK"}
    ):
        # compile the model
        compiled_model = torch.compile(SimpleModel(), mode="max-autotune")
        # run the compiled model
        _ = compiled_model(A, B)
```

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added inline documentation which enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

zjing14 (Contributor) commented Dec 18, 2024

@illsilin @carlushuang Could you review it?
