TORCH_LIBRARY and m.def Not Working as Documented #92

Open
andylizf opened this issue May 14, 2024 · 1 comment

@andylizf

I've run into a problem: using TORCH_LIBRARY on its own, without the dispatcher API, does not work as expected. According to the PyTorch documentation, the TORCH_LIBRARY macro defines a function that registers custom operators. However, when I follow this approach, I get the following error at runtime:

$ python test/benchmark.py cuda
Traceback (most recent call last):
  File "/home/lizhifei/extension-cpp/test/benchmark.py", line 48, in <module>
    new_h, new_C = LLTM(X, W, b, h, C)
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 11, in lltm
    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 17, in forward
    outputs = torch.ops.extension_cpp.lltm_forward.default(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/_ops.py", line 921, in __getattr__
    raise AttributeError(
AttributeError: '_OpNamespace' object has no attribute 'lltm_forward'

Here is a link to my modified repository where this issue can be reproduced: andylizf/extension-cpp.

Could you please help me understand why this is happening and how to resolve it? Thank you.

Environment Information
  • OS: Windows 11 23H2 22631.3527
  • PyTorch version: 2.3.0
  • How you installed PyTorch: conda
  • Python version: 3.12.3
  • CUDA/cuDNN version: CUDA 12.1, cuDNN 8.9.2
  • GPU models and configuration: NVIDIA GeForce RTX 3090
  • Conda Env:
# packages in environment at /home/lizhifei/miniconda3/envs/extension-cpp:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
blas                      1.0                         mkl    conda-forge
brotli-python             1.1.0           py312h30efb56_1    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
cuda                      12.1.0                        0    nvidia
cuda-cccl                 12.1.109                      0    nvidia/label/cuda-12.1.1
cuda-command-line-tools   12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-compiler             12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-cudart               12.1.105                      0    nvidia
cuda-cudart-dev           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-cudart-static        12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-cuobjdump            12.1.111                      0    nvidia/label/cuda-12.1.1
cuda-cupti                12.1.105                      0    nvidia
cuda-cupti-static         12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-cuxxfilt             12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-demo-suite           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-documentation        12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-driver-dev           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-gdb                  12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-libraries            12.1.0                        0    nvidia
cuda-libraries-dev        12.1.0                        0    nvidia
cuda-libraries-static     12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-nsight               12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nsight-compute       12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-nvcc                 12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvdisasm             12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvml-dev             12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvprof               12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvprune              12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvrtc-dev            12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvrtc-static         12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvtx                 12.1.105                      0    nvidia
cuda-nvvp                 12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-opencl               12.4.127                      0    nvidia
cuda-opencl-dev           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-profiler-api         12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-runtime              12.1.0                        0    nvidia
cuda-sanitizer-api        12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-toolkit              12.1.0                        0    nvidia
cuda-tools                12.1.0                        0    nvidia
cuda-version              12.4                 h3060b56_3    conda-forge
cuda-visual-tools         12.1.0                        0    nvidia
extension-cpp             0.0.1                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.14.0             pyhd8ed1ab_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fsspec                    2024.3.1                 pypi_0    pypi
gds-tools                 1.6.1.9                       0    nvidia/label/cuda-12.1.1
gmp                       6.3.0                h59595ed_1    conda-forge
gnutls                    3.6.13               h85f3911_1    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.7                pyhd8ed1ab_0    conda-forge
intel-openmp              2023.1.0         hdb19cb5_46306  
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h166bdaf_2    conda-forge
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.15                 hfd0df8a_0    conda-forge
ld_impl_linux-64          2.40                 h55db66e_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libblas                   3.9.0           1_h86c2bf4_netlib    conda-forge
libcblas                  3.9.0           5_h92ddd45_netlib    conda-forge
libcublas                 12.1.0.26                     0    nvidia
libcublas-dev             12.1.0.26                     0    nvidia
libcublas-static          12.4.5.8             hd3aeb46_1    conda-forge
libcufft                  11.0.2.4                      0    nvidia
libcufft-dev              11.0.2.4                      0    nvidia
libcufft-static           11.2.1.3             hd3aeb46_1    conda-forge
libcufile                 1.9.1.3                       0    nvidia
libcufile-dev             1.6.1.9                       0    nvidia/label/cuda-12.1.1
libcufile-static          1.6.1.9                       0    nvidia/label/cuda-12.1.1
libcurand                 10.3.5.147                    0    nvidia
libcurand-dev             10.3.2.106                    0    nvidia/label/cuda-12.1.1
libcurand-static          10.3.2.106                    0    nvidia/label/cuda-12.1.1
libcusolver               11.4.4.55                     0    nvidia
libcusolver-dev           11.4.4.55                     0    nvidia
libcusolver-static        11.6.1.9             hd3aeb46_1    conda-forge
libcusparse               12.0.2.55                     0    nvidia
libcusparse-dev           12.0.2.55                     0    nvidia
libcusparse-static        12.3.1.170           hd3aeb46_1    conda-forge
libdeflate                1.17                 h0b41bf4_0    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgfortran-ng            13.2.0               h69a702a_7    conda-forge
libgfortran5              13.2.0               hca663fb_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libhwloc                  2.10.0          default_h2fb2949_1000    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
liblapack                 3.9.0           5_h92ddd45_netlib    conda-forge
libnpp                    12.0.2.50                     0    nvidia
libnpp-dev                12.0.2.50                     0    nvidia
libnpp-static             12.2.5.30            hd3aeb46_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.1.105                      0    nvidia
libnvjitlink-dev          12.1.105                      0    nvidia/label/cuda-12.1.1
libnvjitlink-static       12.4.127             hd3aeb46_1    conda-forge
libnvjpeg                 12.1.1.14                     0    nvidia
libnvjpeg-dev             12.1.1.14                     0    nvidia
libnvjpeg-static          12.3.1.117           ha770c72_1    conda-forge
libnvvm-samples           12.1.105                      0    nvidia/label/cuda-12.1.1
libpng                    1.6.43               h2797004_0    conda-forge
libsqlite                 3.45.3               h2797004_0    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtiff                   4.5.0                h6adf6a1_2    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.6               h232c23b_2    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
markupsafe                2.1.5           py312h98912ed_0    conda-forge
mkl                       2023.1.0         h213fc3f_46344  
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
ncurses                   6.5                  h59595ed_0    conda-forge
nettle                    3.6                  he412f7d_0    conda-forge
networkx                  3.3                pyhd8ed1ab_1    conda-forge
ninja                     1.11.1.1                 pypi_0    pypi
nsight-compute            2023.1.1.4                    0    nvidia/label/cuda-12.1.1
numpy                     1.26.4          py312heda63a1_0    conda-forge
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openssl                   3.3.0                hd590300_0    conda-forge
pillow                    10.3.0          py312h5eee18b_0  
pip                       24.0               pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.12.3          hab00c5b_0_cpython    conda-forge
python_abi                3.12                    4_cp312    conda-forge
pytorch                   2.3.0           py3.12_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0.1           py312h98912ed_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
setuptools                69.5.1             pyhd8ed1ab_0    conda-forge
sympy                     1.12               pyh04b8f61_3    conda-forge
tbb                       2021.12.0            h00ab1b0_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
torchaudio                2.3.0               py312_cu121    pytorch
torchvision               0.18.0              py312_cu121    pytorch
typing_extensions         4.11.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
urllib3                   2.2.1              pyhd8ed1ab_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge
@andylizf andylizf changed the title TORCH_LIBRARY and m.def Not Working as Documented TORCH_LIBRARY and m.def Not Working as Documented May 14, 2024
@crazyboy9103

crazyboy9103 commented Aug 20, 2024

https://github.com/andylizf/extension-cpp/blob/2d49e184f82ab6e1b61e8dc3abb6c7ede65ca37b/extension_cpp/csrc/lltm.cpp#L9

I think it should be TORCH_LIBRARY(extension_cpp, m), not TORCH_LIBRARY(TORCH_EXTENSION_NAME, m), so that the registered namespace matches the torch.ops.extension_cpp lookup in ops.py.
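A plausible explanation, assuming TORCH_LIBRARY records its namespace by stringifying its first argument with #ns (as the PyTorch headers appear to do): the C++ preprocessor does not macro-expand an argument that is an operand of #, so TORCH_LIBRARY(TORCH_EXTENSION_NAME, m) would register operators under the literal name "TORCH_EXTENSION_NAME". Even if the argument were expanded, the build is assumed to define TORCH_EXTENSION_NAME as the last component of the extension name (for example _C), which still would not match torch.ops.extension_cpp. The sketch uses hypothetical mock macros (MOCK_LIBRARY_NAMESPACE, STRINGIFY), not PyTorch's real ones, to show only the preprocessor behavior:

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the flag the build is assumed to pass,
// e.g. -DTORCH_EXTENSION_NAME=_C for an extension named "extension_cpp._C".
#define TORCH_EXTENSION_NAME _C

// Hypothetical mock of a registration macro that records its namespace
// with #ns. An argument used as an operand of # is NOT expanded first.
#define MOCK_LIBRARY_NAMESPACE(ns) #ns

// Classic two-level stringify: the extra indirection expands the
// argument before it reaches the # operator.
#define STRINGIFY_IMPL(x) #x
#define STRINGIFY(x) STRINGIFY_IMPL(x)

// Namespace a #ns-style macro records when given TORCH_EXTENSION_NAME.
std::string namespace_from_macro_arg() {
    return MOCK_LIBRARY_NAMESPACE(TORCH_EXTENSION_NAME);  // "TORCH_EXTENSION_NAME"
}

// Namespace you would get if the argument were expanded first.
std::string namespace_if_expanded() {
    return STRINGIFY(TORCH_EXTENSION_NAME);  // "_C"
}

// Namespace when the literal name is written, as suggested above.
std::string namespace_from_literal() {
    return MOCK_LIBRARY_NAMESPACE(extension_cpp);  // "extension_cpp"
}
```

If this is the cause, the operator might be reachable as torch.ops.TORCH_EXTENSION_NAME.lltm_forward in the current build (once the compiled library is loaded), which would be a quick way to confirm it. Writing the namespace literally, TORCH_LIBRARY(extension_cpp, m), sidesteps both the non-expansion and the _C mismatch.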
