Segmentation Fault after running the PyTorch Ahead-of-time (AOT) export workflow #349

Open
Manas-33 opened this issue Dec 19, 2024 · 2 comments

@Manas-33

When I run the sample AOT workflow from the Colab notebook in my local environment, the script produces the expected result, but it then terminates with a "Segmentation Fault (core dumped)" error.

This is the code I ran:

```python
import torch
import iree.turbine.aot as aot
import numpy as np
import iree.runtime as ireert

torch.manual_seed(0)

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias

linear_module = LinearModule(4, 3)
example_arg = torch.randn(4)
export_output = aot.export(linear_module, example_arg)
compiled_binary = export_output.compile(save_to=None)

config = ireert.Config("local-task")
vm_module = ireert.load_vm_module(
    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
    config,
)

input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
result = vm_module.main(input)
print(result.to_host())
```

From what I can tell, the crash originates from the `load_vm_module` call in this section of the code:

```python
vm_module = ireert.load_vm_module(
    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
    config,
)
```
@ScottTodd
Member

This looks like iree-org/iree#17635.

The workaround is:

```diff
-    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
+    ireert.VmModule.copy_buffer(config.vm_instance, compiled_binary.map_memory()),
```
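As the names suggest, the difference between the two calls is whether the runtime references the caller's buffer or takes its own copy. The pure-Python sketch below is only an analogy: it does not use IREE, and `owner`, `wrapped`, `copied`, `seen_through_view`, and `resize_blocked` are hypothetical names. A `memoryview` aliases memory it does not own, like a wrapped buffer, while `bytes(...)` is an independent snapshot, like a copied one. Python pins the owner while a view exists; native code holding only a raw pointer gets no such guarantee, which would be consistent with a crash when objects are torn down at interpreter exit. The actual root cause is tracked in iree-org/iree#17635.

```python
# Analogy only: plain Python buffers, not the IREE runtime.
# `owner` stands in for the compiled module bytes (hypothetical contents).
owner = bytearray(b"vmfb-v1")

wrapped = memoryview(owner)  # like wrap_buffer: aliases owner's memory, no copy
copied = bytes(owner)        # like copy_buffer: independent snapshot

# Same-length in-place write: allowed even while a view is exported.
owner[5:7] = b"v2"

seen_through_view = wrapped.tobytes()
print(seen_through_view)  # b'vmfb-v2': the view sees the write
print(copied)             # b'vmfb-v1': the copy is unaffected

# Python refuses to reallocate the owner while the view is alive.
try:
    owner.extend(b"!")
    resize_blocked = False
except BufferError as exc:
    resize_blocked = True
    print("resize blocked:", exc)

# Once the view is released, the owner can be resized again.
wrapped.release()
owner.extend(b"!")
print(bytes(owner))  # b'vmfb-v2!'
```

The copy costs one extra allocation of the module's size, in exchange for the runtime no longer depending on the lifetime of the Python object that produced the bytes.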

The sample docs were updated to use `copy_buffer` in iree-org/iree#18620. The sample notebooks should also be updated:

I'm pretty sure those two notebooks are running nightly, so if they are broken then I'd expect to see test failures 🤔.

> the script runs and produces the expected result. But, it terminates with a "Segmentation Fault (core dumped)" error at the end.

Maybe the test runner we have isn't catching the error at the end...
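If a runner judges a notebook mainly by its printed output, a process that prints the right answer and then crashes during teardown could slip through. Here is a minimal sketch of a stricter check, assuming each sample runs as its own process on a POSIX system. `run_sample`, `check_sample`, and `CRASHING_SCRIPT` are hypothetical names, not part of IREE's test infrastructure, and the crash is simulated by having the child send itself `SIGSEGV`.

```python
import signal
import subprocess
import sys

# Hypothetical stand-in for a sample that prints the right answer and then
# crashes at exit. The crash is simulated by sending SIGSEGV to itself (POSIX).
CRASHING_SCRIPT = """
import os, signal
print("expected result", flush=True)  # flush, or the crash would discard it
os.kill(os.getpid(), signal.SIGSEGV)
"""


def run_sample(source):
    """Run `source` in a fresh interpreter and return (stdout, returncode)."""
    proc = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
    )
    return proc.stdout, proc.returncode


def check_sample(source, expected_stdout):
    """Return a list of failure messages; an empty list means the sample passed."""
    stdout, returncode = run_sample(source)
    failures = []
    if stdout.strip() != expected_stdout:
        failures.append(f"unexpected output: {stdout!r}")
    if returncode < 0:
        # On POSIX, -N means the child was killed by signal N (-11 is SIGSEGV).
        failures.append(f"killed by {signal.Signals(-returncode).name}")
    elif returncode != 0:
        failures.append(f"exited with status {returncode}")
    return failures


if __name__ == "__main__":
    # The output check alone would pass here; the exit-status check does not.
    print(check_sample(CRASHING_SCRIPT, "expected result"))
```

In this simulation the printed result matches, so only the negative return code reveals the crash; checking the exit status alongside the output is what makes the difference.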

@Manas-33
Author

I tried the workaround, and it no longer throws the error. Thanks!
