As the title says, TRAK produces incorrect results with CUDA 11.8: in the reproduction below, `CudaProjector.project` returns an all-zero tensor.
```python
import torch
from trak.projectors import ProjectionType, AbstractProjector, CudaProjector

grad_dim = int(1e6)
projector = CudaProjector(
    grad_dim=grad_dim,
    proj_dim=32768,
    seed=42,
    proj_type=ProjectionType.normal,
    device='cuda:0',
    max_batch_size=8,
)

grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)
print(proj)
```

Output:

```
tensor([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')
```
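For comparison, a Gaussian (`ProjectionType.normal`) random projection of non-zero input such as `torch.randn(8, grad_dim)` should be dense: each output coordinate is a sum of the inputs weighted by N(0, 1) entries, so its expected square equals the squared input norm. The stdlib-only sketch below uses toy sizes and hypothetical helper names. It is a plain dense reference projection, not TRAK's `fast_jl` kernel; it uses a different RNG, so it will not reproduce TRAK's values, and TRAK may apply its own scaling. It only illustrates the statistics a working projector should roughly show.

```python
import random


def gaussian_projection(grads, proj_dim, seed=42):
    """Project each row of `grads` with a dense N(0, 1) matrix.

    Hypothetical reference helper for toy sizes; NOT TRAK's fast_jl kernel.
    """
    grad_dim = len(grads[0])
    rng = random.Random(seed)
    # Dense (grad_dim x proj_dim) Gaussian matrix, built row by row.
    matrix = [[rng.gauss(0.0, 1.0) for _ in range(proj_dim)] for _ in range(grad_dim)]
    out = []
    for g in grads:
        row = [0.0] * proj_dim
        for i, gi in enumerate(g):
            weights = matrix[i]
            for j in range(proj_dim):
                row[j] += gi * weights[j]
        out.append(row)
    return out


def norm_ratio(grad, proj):
    """||proj||^2 / (proj_dim * ||grad||^2); close to 1 for an N(0, 1) projection."""
    g2 = sum(x * x for x in grad)
    p2 = sum(x * x for x in proj)
    return p2 / (len(proj) * g2)


def is_all_zero(rows):
    """True if every entry of a list of rows is exactly zero (the reported symptom)."""
    return all(x == 0.0 for row in rows for x in row)
```

Applied to the real projector, the same `is_all_zero` idea (for example `proj.abs().max() == 0` on the tensor) flags the failure directly. A ratio that is far from a constant across rows would also suggest a broken kernel.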
### Environment setting

```bash
conda create -n trak python=3.10.16
conda activate trak
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
conda install cuda-nvcc=11.8 -c nvidia -y
conda install -c nvidia cuda-toolkit=11.8 -y
export CUDA_HOME=$CONDA_PREFIX
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
pip install traker[fast]==0.3.2
```
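Because `pip install traker[fast]` builds `fast_jl` with the `nvcc` found via `CUDA_HOME`, while PyTorch comes from the `cu118` wheel index, one quick check is whether both report the same CUDA release. The sketch below is a hedged diagnostic with hypothetical function names: it parses the `release X.Y` string from `nvcc --version` and compares it with `torch.version.cuda`. It assumes `nvcc` resolves to the compiler that built `fast_jl` (e.g. `$CUDA_HOME/bin/nvcc`). A match does not prove the kernel is correct; it only rules out one kind of mismatch.

```python
import re
import subprocess


def parse_nvcc_release(nvcc_output):
    """Return (major, minor) from `nvcc --version` text such as 'release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' string found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def parse_cuda_version(version):
    """Return (major, minor) from a version string such as torch.version.cuda ('11.8')."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def versions_match(nvcc_output, torch_cuda_version):
    """True if nvcc and PyTorch report the same CUDA major.minor release."""
    return parse_nvcc_release(nvcc_output) == parse_cuda_version(torch_cuda_version)


def check_environment(nvcc="nvcc"):
    """Compare the nvcc on PATH with the CUDA version PyTorch was built against.

    Assumes torch is installed; not called at import time.
    """
    import torch  # imported lazily so the parsers above work without torch

    if torch.version.cuda is None:
        raise RuntimeError("this PyTorch build has no CUDA support")
    nvcc_output = subprocess.run(
        [nvcc, "--version"], capture_output=True, text=True, check=True
    ).stdout
    return {
        "nvcc": parse_nvcc_release(nvcc_output),
        "torch": torch.version.cuda,
        "match": versions_match(nvcc_output, torch.version.cuda),
    }
```

Calling `check_environment()` inside the activated environment prints nothing by itself; it returns a small dict that can be pasted into the issue.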
### Environment Info

Python 3.8.20

```
Package                  Version
------------------------ ------------
fast_jl                  0.1.3
filelock                 3.13.1
fsspec                   2024.6.1
Jinja2                   3.1.4
MarkupSafe               2.1.5
mpmath                   1.3.0
networkx                 3.0
numpy                    1.24.1
nvidia-cublas-cu11       11.11.3.6
nvidia-cuda-cupti-cu11   11.8.87
nvidia-cuda-nvrtc-cu11   11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11        9.1.0.70
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.3.0.86
nvidia-cusolver-cu11     11.4.1.48
nvidia-cusparse-cu11     11.7.5.86
nvidia-nccl-cu11         2.20.5
nvidia-nvtx-cu11         11.8.86
pillow                   10.2.0
pip                      24.2
setuptools               75.1.0
sympy                    1.13.1
torch                    2.4.1+cu118
torchvision              0.19.1+cu118
tqdm                     4.67.1
traker                   0.3.2
triton                   3.0.0
typing_extensions        4.12.2
wheel                    0.44.0
```

```
# packages in environment at /data/yonghyun/anaconda3/envs/trak_118:
#
# Name                        Version       Build           Channel
_libgcc_mutex                 0.1           main
_openmp_mutex                 5.1           1_gnu
bzip2                         1.0.8         h5eee18b_6
ca-certificates               2024.12.31    h06a4308_0
cuda-cccl_linux-64            12.8.55       0               nvidia
cuda-command-line-tools       12.8.0        0               nvidia
cuda-compiler                 12.6.2        0               nvidia
cuda-cudart                   12.8.57       0               nvidia
cuda-cudart-dev               12.8.57       0               nvidia
cuda-cudart-dev_linux-64      12.8.57       0               nvidia
cuda-cudart-static            12.8.57       0               nvidia
cuda-cudart-static_linux-64   12.8.57       0               nvidia
cuda-cudart_linux-64          12.8.57       0               nvidia
cuda-cuobjdump                12.8.55       0               nvidia
cuda-cupti                    12.8.57       0               nvidia
cuda-cupti-dev                12.8.57       0               nvidia
cuda-cuxxfilt                 12.8.55       0               nvidia
cuda-documentation            12.4.127      0               nvidia
cuda-driver-dev               12.8.57       0               nvidia
cuda-driver-dev_linux-64      12.8.57       0               nvidia
cuda-gdb                      12.8.55       0               nvidia
cuda-libraries                12.8.0        0               nvidia
cuda-libraries-dev            12.8.0        0               nvidia
cuda-nsight                   12.8.55       0               nvidia
cuda-nvcc                     11.8.89       0               nvidia
cuda-nvdisasm                 12.8.55       0               nvidia
cuda-nvml-dev                 12.8.55       0               nvidia
cuda-nvprof                   12.8.57       0               nvidia
cuda-nvprune                  12.8.55       0               nvidia
cuda-nvrtc                    12.8.61       0               nvidia
cuda-nvrtc-dev                12.8.61       0               nvidia
cuda-nvtx                     12.8.55       0               nvidia
cuda-nvvp                     12.8.57       0               nvidia
cuda-opencl                   12.8.55       0               nvidia
cuda-opencl-dev               12.8.55       0               nvidia
cuda-profiler-api             12.8.55       0               nvidia
cuda-sanitizer-api            12.8.55       0               nvidia
cuda-toolkit                  11.8.0        0               nvidia
cuda-tools                    12.8.0        0               nvidia
cuda-version                  12.8          3               nvidia
cuda-visual-tools             12.8.0        0               nvidia
dbus                          1.13.18       hb2f20db_0
expat                         2.6.4         h6a678d5_0
fast-jl                       0.1.3         pypi_0          pypi
filelock                      3.13.1        pypi_0          pypi
fontconfig                    2.14.1        h55d465d_3
freetype                      2.12.1        h4a9f257_0
fsspec                        2024.6.1      pypi_0          pypi
gds-tools                     1.13.0.11     0               nvidia
glib                          2.78.4        h6a678d5_0
glib-tools                    2.78.4        h6a678d5_0
gmp                           6.3.0         h6a678d5_0
icu                           73.1          h6a678d5_0
jinja2                        3.1.4         pypi_0          pypi
ld_impl_linux-64              2.40          h12ee557_0
libcublas                     12.8.3.14     0               nvidia
libcublas-dev                 12.8.3.14     0               nvidia
libcufft                      11.3.3.41     0               nvidia
libcufft-dev                  11.3.3.41     0               nvidia
libcufile                     1.13.0.11     0               nvidia
libcufile-dev                 1.13.0.11     0               nvidia
libcurand                     10.3.9.55     0               nvidia
libcurand-dev                 10.3.9.55     0               nvidia
libcusolver                   11.7.2.55     0               nvidia
libcusolver-dev               11.7.2.55     0               nvidia
libcusparse                   12.5.7.53     0               nvidia
libcusparse-dev               12.5.7.53     0               nvidia
libffi                        3.4.4         h6a678d5_1
libgcc-ng                     11.2.0        h1234567_1
libglib                       2.78.4        hdc74915_0
libgomp                       11.2.0        h1234567_1
libiconv                      1.16          h5eee18b_3
libnpp                        12.3.3.65     0               nvidia
libnpp-dev                    12.3.3.65     0               nvidia
libnvfatbin                   12.8.55       0               nvidia
libnvfatbin-dev               12.8.55       0               nvidia
libnvjitlink                  12.8.61       1               nvidia
libnvjitlink-dev              12.8.61       1               nvidia
libnvjpeg                     12.3.5.57     0               nvidia
libnvjpeg-dev                 12.3.5.57     0               nvidia
libpng                        1.6.39        h5eee18b_0
libstdcxx-ng                  11.2.0        h1234567_1
libuuid                       1.41.5        h5eee18b_0
libxcb                        1.15          h7f8727e_0
libxkbcommon                  1.0.1         h097e994_2
libxml2                       2.13.5        hfdd30dd_0
markupsafe                    2.1.5         pypi_0          pypi
mpmath                        1.3.0         pypi_0          pypi
ncurses                       6.4           h6a678d5_0
networkx                      3.0           pypi_0          pypi
nsight-compute                2025.1.0.14   0               nvidia
nspr                          4.35          h6a678d5_0
nss                           3.89.1        h6a678d5_0
numpy                         1.24.1        pypi_0          pypi
nvidia-cublas-cu11            11.11.3.6     pypi_0          pypi
nvidia-cuda-cupti-cu11        11.8.87       pypi_0          pypi
nvidia-cuda-nvrtc-cu11        11.8.89       pypi_0          pypi
nvidia-cuda-runtime-cu11      11.8.89       pypi_0          pypi
nvidia-cudnn-cu11             9.1.0.70      pypi_0          pypi
nvidia-cufft-cu11             10.9.0.58     pypi_0          pypi
nvidia-curand-cu11            10.3.0.86     pypi_0          pypi
nvidia-cusolver-cu11          11.4.1.48     pypi_0          pypi
nvidia-cusparse-cu11          11.7.5.86     pypi_0          pypi
nvidia-nccl-cu11              2.20.5        pypi_0          pypi
nvidia-nvtx-cu11              11.8.86       pypi_0          pypi
ocl-icd                       2.3.2         h5eee18b_1
openssl                       3.0.15        h5eee18b_0
pcre2                         10.42         hebb0a14_1
pillow                        10.2.0        pypi_0          pypi
pip                           24.2          py38h06a4308_0
python                        3.8.20        he870216_0
readline                      8.2           h5eee18b_0
setuptools                    75.1.0        py38h06a4308_0
sqlite                        3.45.3        h5eee18b_0
sympy                         1.13.1        pypi_0          pypi
tk                            8.6.14        h39e8969_0
torch                         2.4.1+cu118   pypi_0          pypi
torchvision                   0.19.1+cu118  pypi_0          pypi
tqdm                          4.67.1        pypi_0          pypi
traker                        0.3.2         pypi_0          pypi
triton                        3.0.0         pypi_0          pypi
typing-extensions             4.12.2        pypi_0          pypi
wheel                         0.44.0        py38h06a4308_0
xz                            5.4.6         h5eee18b_1
zlib                          1.2.13        h5eee18b_1
```

### NVIDIA Info

```
Wed Feb 5 15:12:37 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.01             Driver Version: 535.216.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               Off | 00000000:2D:00.0 Off |                    0 |
| N/A   40C    P0              52W / 350W |     17MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
```
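The conda listing mixes CUDA releases: `cuda-nvcc` and `cuda-toolkit` are 11.8, while most other `cuda-*` packages (for example `cuda-cudart` and `cuda-version`) are 12.8, and the driver reports CUDA 12.2. Whether this mix is related to the all-zero output is not established here. The stdlib-only sketch below, with hypothetical function names, groups `cuda-*` rows from `conda list` output by major.minor release so that such a mix is easy to spot. It only looks at the `cuda-` prefix, because `lib*` packages such as `libcublas` use their own library version numbers.

```python
from collections import defaultdict


def cuda_packages_by_release(conda_list_text, prefix="cuda-"):
    """Group `cuda-*` packages from `conda list` output by major.minor version.

    Expects whitespace-separated rows: name, version, build[, channel].
    """
    groups = defaultdict(list)
    for line in conda_list_text.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue  # skip blank lines and the '#' header lines
        fields = stripped.split()
        if len(fields) < 2 or not fields[0].startswith(prefix):
            continue
        name, version = fields[0], fields[1]
        release = ".".join(version.split(".")[:2])
        groups[release].append(f"{name}=={version}")
    return dict(groups)


def has_mixed_releases(groups):
    """True when the cuda-* packages span more than one major.minor release."""
    return len(groups) > 1
```

To use it, save the listing with `conda list > conda_list.txt`, read the file into a string, and pass it to `cuda_packages_by_release`.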