Unable to compile model with GATv2Conv layers #9603

Closed
jusevitch opened this issue Aug 17, 2024 · 5 comments

@jusevitch

jusevitch commented Aug 17, 2024

🐛 Describe the bug

I'm unable to use torch.compile to compile a simple model that uses GATv2Conv layers. An MWE is below.

(Edit: I get a similar error when using torch_geometric.compile().)

MWE

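import torch
# Let Dynamo capture ops whose output shape depends on input data (e.g. nonzero) instead of graph-breaking on them.
torch._dynamo.config.capture_dynamic_output_shape_ops = True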

import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv as conv
from torch_geometric.datasets import FakeDataset


class GNN(torch.nn.Module):
    def __init__(self, features, classes, hidden_width, layers):
        super().__init__()

        self.layers = torch.nn.ModuleList([conv(features, hidden_width, heads=1)])
        # in_channels=-1: PyG infers the input size lazily on the first forward pass.
        for _ in range(layers - 2):
            self.layers.append(conv(-1, hidden_width, heads=1))
        self.layers.append(conv(-1, classes))
        self.act = F.gelu

    def forward(self, x, edge_index):

        for l in self.layers[:-1]:
            x = l(x, edge_index)
            x = self.act(x)
        x = self.layers[-1](x, edge_index)

        return x


if __name__ == "__main__":

    num_channels = 2
    num_classes = 2

    model = GNN(num_channels, num_classes, 4, 4)
    model = torch.compile(model, dynamic=True, fullgraph=True)

    dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")

    for data in dataset:
        out = model(data.x, data.edge_index)
        print(out)
Error Output
Traceback (most recent call last):                                                                                                                                                                          
  File "/home/usevitch/code/python/neurosub/compile_bug_mwe.py", line 45, in <module>                                                                                                                       
    out = model(data.x, data.edge_index)                                                                                                                                                                    
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                    
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl                                                            
    return self._call_impl(*args, **kwargs)                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                 
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                    
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                    
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn                                                                           
    return fn(*args, **kwargs)                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                              
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl                                                            
    return self._call_impl(*args, **kwargs)                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                 
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                    
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                    
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__                                                                  
    return self._torchdynamo_orig_callable(                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                 
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__                                                                   
    return _compile(                                                                                                                                                                                        
           ^^^^^^^^^                                                                                                                                                                                        
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function                                                                  
    return StrobelightCompileTimeProfiler.profile_compile_time(                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                             
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time                                          
    return func(*args, **kwargs)                                                                                                                                                                            
           ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                            
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/contextlib.py", line 81, in inner                                                                                                      
    return func(*args, **kwds)                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                              
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile                                                                   
    guarded_code = compile_inner(code, one_graph, hooks, transform)                                                                                                                                         
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                         
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper                                                                       
    r = func(*args, **kwargs)                                                                                                                                                                               
        ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                               
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner                                                              
    out_code = transform_code_object(code, transform)                                                                                                                                                       
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                       
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object                                           
    transformations(instructions, code_options)                                                                                                                                                             
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn                                                                        
    return fn(*args, **kwargs)                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                              
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform                                                                  
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs)) 
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function
    return tx.inline_user_function_return(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs)) 
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs)) 
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs)) 
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/user_defined.py", line 448, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 301, in call_function
    unimplemented(f"call_function {self} {args} {kwargs}")
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable(<class 'torch_geometric.edge_index.EdgeIndex'>) [TensorVariable()] {'sparse_size': TupleVariable(), 'is_undirected': ConstantVariable()}

from user code:
   File "/home/usevitch/code/python/neurosub/compile_bug_mwe.py", line 27, in forward
    x = l(x, edge_index)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch_geometric/nn/conv/gatv2_conv.py", line 285, in forward
    edge_index, edge_attr = add_self_loops(
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch_geometric/utils/loop.py", line 466, in add_self_loops
    loop_index = EdgeIndex(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
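The frame that fails is PyG's add_self_loops, which GATv2Conv.forward calls by default (add_self_loops=True) after removing any existing self-loops. As a rough illustration of what that step computes, here is a plain-tensor sketch. add_self_loops_sketch is a hypothetical helper, not PyG's implementation, and it deliberately avoids the EdgeIndex subclass that Dynamo rejects above.

```python
import torch


def add_self_loops_sketch(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: drop existing (i, i) edges, then append one (i, i) per node.

    edge_index is a [2, E] integer tensor of (source, target) pairs.
    Uses plain tensors only (no EdgeIndex subclass).
    """
    # Keep only edges whose source and target differ.
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    # One self-loop per node: [[0, 1, ..., N-1], [0, 1, ..., N-1]].
    loop_index = torch.arange(num_nodes, device=edge_index.device).repeat(2, 1)
    return torch.cat([edge_index, loop_index], dim=1)


if __name__ == "__main__":
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 1, 0]])
    print(add_self_loops_sketch(edge_index, num_nodes=3))
```

One untested way to use this would be to precompute self-loops eagerly, outside the compiled model, and construct each GATv2Conv with add_self_loops=False so the failing call never runs inside the compiled graph.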

Versions

Version Information:

PyTorch version: 2.4.0+cu121                                                                                                                                                                                
Is debug build: False                                                                                                                                                                                       
CUDA used to build PyTorch: 12.1                                                                                                                                                                            
ROCM used to build PyTorch: N/A                                                                                                                                                                             
                                                                                                                                                                                                            
OS: Ubuntu 22.04.4 LTS (x86_64)                                                                                                                                                                             
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0                                                                                                                                                          
Clang version: Could not collect                                                                                                                                                                            
CMake version: version 3.22.1                                                                                                                                                                               
Libc version: glibc-2.35                                                                                                                                                                                    
                                                                                                                                                                                                            
Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)                                                                                              
Python platform: Linux-6.5.0-44-generic-x86_64-with-glibc2.35                                                                                                                                               
Is CUDA available: True                                                                                                                                                                                     
CUDA runtime version: 12.2.140                                                                                                                                                                              
CUDA_MODULE_LOADING set to: LAZY                                                                                                                                                                            
GPU models and configuration:                                                                                                                                                                               
GPU 0: NVIDIA GeForce RTX 4090                                                                                                                                                                              
GPU 1: NVIDIA GeForce RTX 4090                                                                                                                                                                              
                                                                                                                                                                                                            
Nvidia driver version: 550.90.07                                                                                                                                                                            
cuDNN version: Could not collect                                                                                                                                                                            
HIP runtime version: N/A                                                                                                                                                                                    
MIOpen runtime version: N/A                                                                                                                                                                                 
Is XNNPACK available: True                                                                                                                                                                                  
                                                                                                                                                                                                            
CPU:                                                                                                                                                                                                        
Architecture:                       x86_64                                                                                                                                                                  
CPU op-mode(s):                     32-bit, 64-bit                                                                                                                                                          
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             64
On-line CPU(s) list:                0-63
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen Threadripper PRO 5975WX 32-Cores                                                                                                                              
CPU family:                         25                                                                                                                                                                      
Model:                              8
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          1
Stepping:                           2
Frequency boost:                    enabled
CPU max MHz:                        7006.6401
CPU min MHz:                        1800.0000
BogoMIPS:                           7186.86
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          1 MiB (32 instances)
L1i cache:                          1 MiB (32 instances)
L2 cache:                           16 MiB (32 instances)
L3 cache:                           128 MiB (4 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] torch==2.4.0
[pip3] torch_cluster==1.6.3+pt24cu121
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2+pt24cu121
[pip3] torch_sparse==0.6.18+pt24cu121
[pip3] torch_spline_conv==1.2.2+pt24cu121
[pip3] torchaudio==2.4.0
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] numpy                     2.0.1                    pypi_0    pypi
[conda] torch                     2.4.0                    pypi_0    pypi
[conda] torch-cluster             1.6.3+pt24cu121          pypi_0    pypi
[conda] torch-geometric           2.5.3                    pypi_0    pypi
[conda] torch-scatter             2.1.2+pt24cu121          pypi_0    pypi
[conda] torch-sparse              0.6.18+pt24cu121          pypi_0    pypi
[conda] torch-spline-conv         1.2.2+pt24cu121          pypi_0    pypi
[conda] torchaudio                2.4.0                    pypi_0    pypi
[conda] torchvision               0.19.0                   pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
@akihironitta
Member

#9007 added support for torch.compile with the tensor subclass EdgeIndex. Would you mind trying again with master? https://github.com/pyg-team/pytorch_geometric/#nightly-and-master
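Before retrying, it can help to confirm that the interpreter actually picks up the master build rather than the 2.5.3 release listed in the environment report above. The following is a minimal stdlib-only sketch. The helper names release_tuple and installed_release are hypothetical. It assumes that a master install reports a release number greater than 2.5.3; the exact version string of master is not stated in this thread.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str) -> tuple[int, ...]:
    """Return the numeric release part of a version string.

    "2.1.2+pt24cu121" -> (2, 1, 2); "2.6.0.dev0" -> (2, 6, 0).
    Pre-release, dev and local suffixes are ignored (a simplification).
    """
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    return tuple(int(part) for part in m.group(0).split("."))


def installed_release(dist: str) -> tuple[int, ...] | None:
    """Numeric release of an installed distribution, or None if it is absent."""
    try:
        return release_tuple(version(dist))
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumption: a master install reports something newer than 2.5.3.
    pyg = installed_release("torch_geometric")
    if pyg is None:
        print("torch_geometric is not installed")
    elif pyg <= (2, 5, 3):
        print("still on 2.5.3 or older; the master build is not active")
    else:
        print("newer than 2.5.3:", ".".join(map(str, pyg)))
```

Because suffixes such as .dev0 or +pt24cu121 are dropped, this only compares release numbers. If the packaging library is available, packaging.version.Version gives a complete PEP 440 comparison instead.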

@jusevitch
Author

After installing the master branch with pip install git+https://github.com/pyg-team/pytorch_geometric.git, the original error is gone, but a new one appears. It seems to be related to TupleVariable(). The output is below.

Error output:
Traceback (most recent call last):
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 646, in proxy_args_kwargs
    proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 646, in <dictcomp>
    proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
                         ^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 660, in as_proxy
    return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
                                                ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 253, in as_proxy
    raise NotImplementedError(str(self))
NotImplementedError: GetSetDescriptorVariable()

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/usevitch/code/python/neurosub/compile_error_mwe.py", line 40, in <module>
    out = model(data.x, data.edge_index)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 437, in call_function
    return tx.inline_user_function_return(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 838, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1500, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run                                                                                                                                                                                                                                             
    while self.step():                                                                                                                                                                                                                                                                                                                                                              
          ^^^^^^^^^^^                                                                                                                                                                                                                                                                                                                                                               
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 344, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 749, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2666, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2782, in inline_call_
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/torch.py", line 762, in call_function
    *proxy_args_kwargs(args, kwargs),
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 652, in proxy_args_kwargs
    unimplemented(
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 220, in unimplemented
    raise Unsupported(msg) from from_exc
torch._dynamo.exc.Unsupported: call_function args: TupleVariable() GetAttrVariable(GetSetDescriptorVariable(), device) GetAttrVariable(GetSetDescriptorVariable(), dtype)

from user code:
   File "/home/usevitch/code/python/neurosub/compile_error_mwe.py", line 22, in forward
    x = l(x, edge_index)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch_geometric/nn/conv/gatv2_conv.py", line 286, in forward
    x_l = self.lin_l(x).view(-1, H, C)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1592, in _call_impl
    args_result = hook(self, args)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch_geometric/nn/dense/linear.py", line 153, in initialize_parameters
    self.weight.materialize((self.out_channels, self.in_channels))
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/parameter.py", line 124, in materialize
    self.data = torch.empty(shape, device=device, dtype=dtype)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

@akihironitta
Member

Thank you for providing the repro and the complete error message. A quick workaround is to initialize the parameters (by running one eager forward pass) before passing the model to torch.compile:

     model = GNN(num_channels, num_classes, 4, 4)
+    dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")
+    model(dataset[0].x, dataset[0].edge_index)
     model = torch.compile(model, dynamic=True, fullgraph=True)

-    dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")
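The workaround in the diff above can be packaged as a small helper. This is a minimal sketch, not PyG or PyTorch API: `materialize_then_compile` is a hypothetical name, and `torch.nn.LazyLinear` stands in for PyG layers built with `in_channels=-1`. Both rely on `UninitializedParameter.materialize`, which is the call that fails to trace in the first traceback.

```python
import torch


def materialize_then_compile(model: torch.nn.Module, *example_inputs, **compile_kwargs):
    """Hypothetical helper: run one eager forward pass so that lazily
    initialized parameters get their real shapes, then hand the fully
    materialized model to torch.compile."""
    with torch.no_grad():  # the warm-up output is discarded, so skip autograd
        model(*example_inputs)
    return torch.compile(model, **compile_kwargs)


# Usage, with torch.nn.LazyLinear standing in for PyG layers built with -1.
model = torch.nn.Sequential(
    torch.nn.LazyLinear(4), torch.nn.GELU(), torch.nn.LazyLinear(2)
)
x = torch.randn(5, 3)
compiled = materialize_then_compile(model, x, fullgraph=True)
print(model[0].weight.shape)  # torch.Size([4, 3])
```

Note that the warm-up is a real forward pass: layers with running statistics (e.g. BatchNorm in training mode) will update them once.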

We should decorate initialize_parameters with torch._dynamo.disable so that PyG users don't need to think about this.
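A minimal sketch of that suggestion, written in plain PyTorch with a hypothetical `LazyLinearSketch` layer rather than PyG's actual `torch_geometric.nn.Linear` (whose real code may differ). `torch._dynamo.disable` tells Dynamo not to trace the decorated function; when compiled code reaches it, Dynamo graph-breaks and runs it eagerly. As far as we understand, a graph break is still an error under `fullgraph=True` (as used in the MWE), so the eager warm-up call may still be needed in that mode.

```python
import torch
import torch._dynamo
from torch.nn.parameter import UninitializedParameter


class LazyLinearSketch(torch.nn.Module):
    """Hypothetical stand-in for a layer built with in_channels=-1.

    The weight starts uninitialized and is materialized by a forward
    pre-hook the first time the layer sees an input.
    """

    def __init__(self, out_channels: int):
        super().__init__()
        self.in_channels = -1
        self.out_channels = out_channels
        self.weight = UninitializedParameter()
        self._hook = self.register_forward_pre_hook(self.initialize_parameters)

    # Keep Dynamo out of the lazy initialization: the hook always runs
    # eagerly, so UninitializedParameter.materialize() is never traced.
    @torch._dynamo.disable
    def initialize_parameters(self, module, inputs):
        if isinstance(self.weight, UninitializedParameter):
            self.in_channels = inputs[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            torch.nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        self._hook.remove()  # one-shot hook: the shapes are now known

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.t()


layer = LazyLinearSketch(out_channels=4)
out = layer(torch.randn(6, 3))
print(out.shape)  # torch.Size([6, 4])
```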

@jusevitch
Author

Thanks for this suggestion. Unfortunately, this still doesn't run for me; the error message says "Please file an issue on GitHub so the PyTorch team can add support for it." I'll file a bug in the PyTorch repo.

New MWE
import torch
torch._dynamo.config.capture_dynamic_output_shape_ops = True

import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv as conv
from torch_geometric.datasets import FakeDataset


class GNN(torch.nn.Module):
    def __init__(self, features, classes, hidden_width, layers):
        super().__init__()

        self.layers = torch.nn.ModuleList([conv(features, hidden_width, heads=1)])
        for i in range(layers-2):
            self.layers.append(conv(-1, hidden_width, heads=1))
        self.layers.append(conv(-1, classes))
        self.act = F.gelu

    def forward(self, x, edge_index):

        for l in self.layers[:-1]:
            x = l(x, edge_index)
            x = self.act(x)
        x = self.layers[-1](x, edge_index)

        return x


if __name__ == "__main__":

    num_channels = 2
    num_classes = 2

    model = GNN(num_channels, num_classes, 4, 4)
    dataset = FakeDataset(num_channels=num_channels, num_classes=num_classes, task="node")
    model(dataset[0].x, dataset[0].edge_index)
    model = torch.compile(model, dynamic=True, fullgraph=True)

    for data in dataset:
        out = model(data.x, data.edge_index)
        print(out)
Error output
E0819 22:05:31.170000 138952784099136 torch/fx/experimental/recording.py:281] [0/0] failed while running defer_runtime_assert(*(Eq(s1 + u0, s1 + u3), '/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/__init__.py:1318'), **{'fx_node': None})
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed   
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] Traceback (most recent call last):                                                          
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 547, in aot_dispatch_autograd
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     compiled_bw_func = aot_config.bw_compiler(                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                        ^^^^^^^^^^^^^^^^^^^^^^^                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 47, in _wrapped_bw_compiler
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return disable(disable(bw_compiler)(*args, **kwargs))                                   
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                    
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return fn(*args, **kwargs)                                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^                                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return StrobelightCompileTimeProfiler.profile_compile_time(                             
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                             
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return func(*args, **kwargs)                                                            
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^                                                            
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     r = func(*args, **kwargs)                                                               
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]         ^^^^^^^^^^^^^^^^^^^^^                                                               
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1454, in bw_compiler
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return inner_compile(                                                                   
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^                                                                   
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     inner_compiled_fn = compiler_fn(gm, example_inputs)                                     
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                     
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return fn(*args, **kwargs)                                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^                                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/contextlib.py", line 81, in inner
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return func(*args, **kwds)                                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^                                                              
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/contextlib.py", line 81, in inner
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return func(*args, **kwds)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 527, in compile_fx_inner
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     compiled_graph = fx_codegen_and_compile(
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                      ^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/contextlib.py", line 81, in inner
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return func(*args, **kwds)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 738, in fx_codegen_and_compile
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     fake_mode = fake_tensor_prop(gm, example_inputs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 379, in fake_tensor_prop
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py", line 69, in propagate_dont_convert_inputs
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return super().run(*args)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     self.env[node] = self.run_node(node)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                      ^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py", line 37, in run_node
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     result = super().run_node(n)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]              ^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return getattr(self, n.op)(n.target, args, kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/interpreter.py", line 275, in call_function
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return target(*args, **kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_ops.py", line 667, in __call__
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return self_._op(*args, **kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return fn(*args, **kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return self.dispatch(func, types, args, kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return self._cached_dispatch_impl(func, types, args, kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1153, in _cached_dispatch_impl
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     output = self._dispatch_impl(func, types, args, kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1671, in _dispatch_impl
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 1062, in fast_binary_impl
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     final_shape = infer_size(final_shape, shape)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 1016, in infer_size
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     torch._check(
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/__init__.py", line 1353, in _check
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     _check_with(RuntimeError, cond, message)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/__init__.py", line 1318, in _check_with
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     if expect_true(cond):
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]        ^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 946, in expect_true
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 435, in expect_true
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return self.shape_env.defer_runtime_assert(
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 245, in wrapper
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return fn(*args, **kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5338, in defer_runtime_assert
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     assert not self.runtime_asserts_frozen, expr
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] AssertionError: Eq(s1 + u0, s1 + u3)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] 
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] While executing %mul_21 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%gather, %index_19), kwargs = {})
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] Original traceback:
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/code/python/neurosub/compile_error_mwe.py", line 24, in forward
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     x = self.layers[-1](x, edge_index)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return forward_call(*args, **kwargs)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch_geometric/nn/conv/gatv2_conv.py", line 329, in forward
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/tmp/torch_geometric.nn.conv.gatv2_conv_GATv2Conv_propagate_ch0xd74l.py", line 176, in propagate
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     out = self.message(
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]   File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch_geometric/nn/conv/gatv2_conv.py", line 375, in message
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0]     return x_j * alpha.unsqueeze(-1)
W0819 22:05:31.171000 138952784099136 torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:551] [0/0] 
Traceback (most recent call last):
  File "/home/usevitch/code/python/neurosub/compile_error_mwe.py", line 41, in <module>
    out = model(data.x, data.edge_index)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/code/python/neurosub/compile_error_mwe.py", line 19, in forward
    def forward(self, x, edge_index):
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/container.py", line 295, in __getitem__
    return self.__class__(list(self._modules.values())[idx])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/container.py", line 281, in __init__
    self += modules
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/container.py", line 322, in __iadd__
    return self.extend(modules)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/nn/modules/container.py", line 399, in extend
    if not isinstance(modules, container_abcs.Iterable):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function                                                                  
    return StrobelightCompileTimeProfiler.profile_compile_time(                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/contextlib.py", line 81, in inner 
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
    self.push(fn.call_function(self, args, kwargs)) 
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 665, in call_function
    unimplemented(msg)
  File "/home/usevitch/mambaforge/envs/neurosub_debug/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Graph break due to unsupported Python builtin _abc._abc_instancecheck. Please file an issue on GitHub so the PyTorch team can add support for it. 

from user code:
   File "<frozen abc>", line 119, in __instancecheck__

@akihironitta
Member

@jusevitch I ran the script again on PyG master with PyTorch nightly, and it worked without any issues. Can you retry with newer versions of PyG and PyTorch?


We should decorate initialize_parameters with torch._dynamo.disable so that PyG users don't need to think about this.

This wouldn't work because disabling the region will produce a graph break anyway. I don't see any other solution that lets users call torch.compile(..., fullgraph=True) before initializing parameters, and I think this is a limitation of the lazy parameter initialization in PyG. At least for now, users are expected to run torch.compile after initializing parameters.
