TRAK on cuda version 11.8 have numerical issue #81

Open
enkeejunior1 opened this issue Feb 4, 2025 · 0 comments


enkeejunior1 commented Feb 4, 2025

As the title says, TRAK has a numerical issue on CUDA 11.8: in the example below, CudaProjector.project returns a tensor that prints as all zeros.

  • Minimal reproducible code
import torch
from trak.projectors import ProjectionType, CudaProjector

# Project 8 random gradients of dimension 1e6 down to 32768 dimensions.
grad_dim = int(1e6)
projector = CudaProjector(
    grad_dim=grad_dim,
    proj_dim=32768,
    seed=42,
    proj_type=ProjectionType.normal,
    device='cuda:0',
    max_batch_size=8,
)
grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)
print(proj)  # expected: non-zero values; observed: all zeros (see output below)

>>> tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
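The printed tensor elides most entries with `...`, so it does not show on its own that every value is zero. A minimal sketch of a stricter check follows. The helper name `summarize` is hypothetical and not part of TRAK; it takes plain Python floats so it has no dependencies.

```python
import math
from typing import Dict, Iterable


def summarize(values: Iterable[float]) -> Dict[str, float]:
    """Count total, non-zero, and non-finite entries; track the largest finite magnitude."""
    total = nonzero = nonfinite = 0
    max_abs = 0.0
    for v in values:
        total += 1
        if not math.isfinite(v):
            # NaN and +/-inf are counted separately from zeros.
            nonfinite += 1
            continue
        if v != 0.0:
            nonzero += 1
            max_abs = max(max_abs, abs(v))
    return {"total": total, "nonzero": nonzero, "nonfinite": nonfinite, "max_abs": max_abs}


# Hypothetical usage with the tensor from the reproduction script:
#   print(summarize(proj.flatten().tolist()))
print(summarize([0.0, 0.0, 0.0]))
print(summarize([0.0, -2.5, float("nan")]))
```

If `nonzero` is 0 and `nonfinite` is 0 for the full tensor, the projection is exactly zero everywhere rather than merely small in the printed corners.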
  • How I installed TRAK
conda create -n trak python=3.10.16
conda activate trak
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
conda install cuda-nvcc=11.8 -c nvidia -y
conda install -c nvidia cuda-toolkit=11.8 -y
export CUDA_HOME=$CONDA_PREFIX
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
pip install traker[fast]==0.3.2
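Since the steps above install PyTorch and nvcc separately, it may help to confirm which CUDA versions the environment actually resolves to. The sketch below is only a diagnostic, under the assumption that `nvcc` is on `PATH`. The function names `parse_nvcc_release` and `report` are hypothetical, and it does not establish the cause of the zero output.

```python
import os
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(text: str) -> Optional[str]:
    """Extract the 'release X.Y' version from `nvcc --version` output, if present."""
    match = re.search(r"release (\d+\.\d+)", text)
    return match.group(1) if match else None


def report() -> None:
    """Print the CUDA versions seen by the shell, nvcc, and PyTorch."""
    print("CUDA_HOME:", os.environ.get("CUDA_HOME"))
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        print("nvcc: not found on PATH")
    else:
        out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
        print("nvcc release:", parse_nvcc_release(out))
    try:
        import torch  # optional; only needed for the last line of the report
        print("torch built for CUDA:", torch.version.cuda)
    except ImportError:
        print("torch: not installed")


if __name__ == "__main__":
    report()
```

Running this inside the activated conda environment shows whether the nvcc release and the CUDA version PyTorch was built for agree.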
Environment settings
### Environment Info
Python 3.8.20
Package                  Version
------------------------ ------------
fast_jl                  0.1.3
filelock                 3.13.1
fsspec                   2024.6.1
Jinja2                   3.1.4
MarkupSafe               2.1.5
mpmath                   1.3.0
networkx                 3.0
numpy                    1.24.1
nvidia-cublas-cu11       11.11.3.6
nvidia-cuda-cupti-cu11   11.8.87
nvidia-cuda-nvrtc-cu11   11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11        9.1.0.70
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.3.0.86
nvidia-cusolver-cu11     11.4.1.48
nvidia-cusparse-cu11     11.7.5.86
nvidia-nccl-cu11         2.20.5
nvidia-nvtx-cu11         11.8.86
pillow                   10.2.0
pip                      24.2
setuptools               75.1.0
sympy                    1.13.1
torch                    2.4.1+cu118
torchvision              0.19.1+cu118
tqdm                     4.67.1
traker                   0.3.2
triton                   3.0.0
typing_extensions        4.12.2
wheel                    0.44.0
# packages in environment at /data/yonghyun/anaconda3/envs/trak_118:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.12.31           h06a4308_0  
cuda-cccl_linux-64        12.8.55                       0    nvidia
cuda-command-line-tools   12.8.0                        0    nvidia
cuda-compiler             12.6.2                        0    nvidia
cuda-cudart               12.8.57                       0    nvidia
cuda-cudart-dev           12.8.57                       0    nvidia
cuda-cudart-dev_linux-64  12.8.57                       0    nvidia
cuda-cudart-static        12.8.57                       0    nvidia
cuda-cudart-static_linux-64 12.8.57                       0    nvidia
cuda-cudart_linux-64      12.8.57                       0    nvidia
cuda-cuobjdump            12.8.55                       0    nvidia
cuda-cupti                12.8.57                       0    nvidia
cuda-cupti-dev            12.8.57                       0    nvidia
cuda-cuxxfilt             12.8.55                       0    nvidia
cuda-documentation        12.4.127                      0    nvidia
cuda-driver-dev           12.8.57                       0    nvidia
cuda-driver-dev_linux-64  12.8.57                       0    nvidia
cuda-gdb                  12.8.55                       0    nvidia
cuda-libraries            12.8.0                        0    nvidia
cuda-libraries-dev        12.8.0                        0    nvidia
cuda-nsight               12.8.55                       0    nvidia
cuda-nvcc                 11.8.89                       0    nvidia
cuda-nvdisasm             12.8.55                       0    nvidia
cuda-nvml-dev             12.8.55                       0    nvidia
cuda-nvprof               12.8.57                       0    nvidia
cuda-nvprune              12.8.55                       0    nvidia
cuda-nvrtc                12.8.61                       0    nvidia
cuda-nvrtc-dev            12.8.61                       0    nvidia
cuda-nvtx                 12.8.55                       0    nvidia
cuda-nvvp                 12.8.57                       0    nvidia
cuda-opencl               12.8.55                       0    nvidia
cuda-opencl-dev           12.8.55                       0    nvidia
cuda-profiler-api         12.8.55                       0    nvidia
cuda-sanitizer-api        12.8.55                       0    nvidia
cuda-toolkit              11.8.0                        0    nvidia
cuda-tools                12.8.0                        0    nvidia
cuda-version              12.8                          3    nvidia
cuda-visual-tools         12.8.0                        0    nvidia
dbus                      1.13.18              hb2f20db_0  
expat                     2.6.4                h6a678d5_0  
fast-jl                   0.1.3                    pypi_0    pypi
filelock                  3.13.1                   pypi_0    pypi
fontconfig                2.14.1               h55d465d_3  
freetype                  2.12.1               h4a9f257_0  
fsspec                    2024.6.1                 pypi_0    pypi
gds-tools                 1.13.0.11                     0    nvidia
glib                      2.78.4               h6a678d5_0  
glib-tools                2.78.4               h6a678d5_0  
gmp                       6.3.0                h6a678d5_0  
icu                       73.1                 h6a678d5_0  
jinja2                    3.1.4                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0  
libcublas                 12.8.3.14                     0    nvidia
libcublas-dev             12.8.3.14                     0    nvidia
libcufft                  11.3.3.41                     0    nvidia
libcufft-dev              11.3.3.41                     0    nvidia
libcufile                 1.13.0.11                     0    nvidia
libcufile-dev             1.13.0.11                     0    nvidia
libcurand                 10.3.9.55                     0    nvidia
libcurand-dev             10.3.9.55                     0    nvidia
libcusolver               11.7.2.55                     0    nvidia
libcusolver-dev           11.7.2.55                     0    nvidia
libcusparse               12.5.7.53                     0    nvidia
libcusparse-dev           12.5.7.53                     0    nvidia
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libglib                   2.78.4               hdc74915_0  
libgomp                   11.2.0               h1234567_1  
libiconv                  1.16                 h5eee18b_3  
libnpp                    12.3.3.65                     0    nvidia
libnpp-dev                12.3.3.65                     0    nvidia
libnvfatbin               12.8.55                       0    nvidia
libnvfatbin-dev           12.8.55                       0    nvidia
libnvjitlink              12.8.61                       1    nvidia
libnvjitlink-dev          12.8.61                       1    nvidia
libnvjpeg                 12.3.5.57                     0    nvidia
libnvjpeg-dev             12.3.5.57                     0    nvidia
libpng                    1.6.39               h5eee18b_0  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
libxcb                    1.15                 h7f8727e_0  
libxkbcommon              1.0.1                h097e994_2  
libxml2                   2.13.5               hfdd30dd_0  
markupsafe                2.1.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.0                      pypi_0    pypi
nsight-compute            2025.1.0.14                   0    nvidia
nspr                      4.35                 h6a678d5_0  
nss                       3.89.1               h6a678d5_0  
numpy                     1.24.1                   pypi_0    pypi
nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
nvidia-cudnn-cu11         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.3.0.86                pypi_0    pypi
nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
nvidia-nccl-cu11          2.20.5                   pypi_0    pypi
nvidia-nvtx-cu11          11.8.86                  pypi_0    pypi
ocl-icd                   2.3.2                h5eee18b_1  
openssl                   3.0.15               h5eee18b_0  
pcre2                     10.42                hebb0a14_1  
pillow                    10.2.0                   pypi_0    pypi
pip                       24.2             py38h06a4308_0  
python                    3.8.20               he870216_0  
readline                  8.2                  h5eee18b_0  
setuptools                75.1.0           py38h06a4308_0  
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.13.1                   pypi_0    pypi
tk                        8.6.14               h39e8969_0  
torch                     2.4.1+cu118              pypi_0    pypi
torchvision               0.19.1+cu118             pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
traker                    0.3.2                    pypi_0    pypi
triton                    3.0.0                    pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
wheel                     0.44.0           py38h06a4308_0  
xz                        5.4.6                h5eee18b_1  
zlib                      1.2.13               h5eee18b_1  
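In the conda listing above, `cuda-nvcc` and `cuda-toolkit` report 11.8 while most other `cuda-*` packages report 12.x. Whether this mix is related to the zero output is not established in this report. A minimal sketch for summarizing such a listing by major version follows; the helper name `cuda_majors` is hypothetical, and the sample rows are copied from the listing.

```python
from collections import defaultdict
from typing import Dict, List


def cuda_majors(conda_list_text: str) -> Dict[str, List[str]]:
    """Group `cuda-*` packages from `conda list` output by major version (column two)."""
    groups = defaultdict(list)
    for line in conda_list_text.splitlines():
        parts = line.split()
        # Skip comment lines, short lines, and packages not named cuda-*.
        if len(parts) < 2 or line.startswith("#") or not parts[0].startswith("cuda-"):
            continue
        name, version = parts[0], parts[1]
        groups[version.split(".")[0]].append(name)
    return dict(groups)


sample = """\
cuda-cudart               12.8.57                       0    nvidia
cuda-nvcc                 11.8.89                       0    nvidia
cuda-toolkit              11.8.0                        0    nvidia
cuda-version              12.8                          3    nvidia
"""
print(cuda_majors(sample))
```

To use it on a real environment, pass the full text of `conda list` instead of `sample`.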

### NVIDIA Info
Wed Feb  5 15:12:37 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.01             Driver Version: 535.216.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               Off | 00000000:2D:00.0 Off |                    0 |
| N/A   40C    P0              52W / 350W |     17MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
enkeejunior1 changed the title from "TRAK on H100 have numerical issue" to "TRAK on cuda version 11.8 have numerical issue" on Feb 4, 2025