Some model tests are failing on GPU #7098

Open
YosuaMichael opened this issue Jan 17, 2023 · 2 comments

Contributor

YosuaMichael commented Jan 17, 2023

Currently, some model tests are failing on Linux GPU on GHA.

Error observations:

Here is a sample of the errors from a run on 17 January 2023:

FAILED test/test_models.py::test_classification_model[cuda-resnet101] - AssertionError: Tensor-likes are not close!

Mismatched elements: 14 / 50 (28.0%)
Greatest absolute difference: 9.2578125 at index (0, 29) (up to 0.001 allowed)
Greatest relative difference: 0.16600049957503943 at index (0, 22) (up to 0.001 allowed)
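The failure message above is the kind produced by an element-wise closeness check. As a rough sketch (assuming the criterion documented for `torch.allclose` and `torch.testing.assert_close`, where an element is close when |actual - expected| <= atol + rtol * |expected|), the three summary lines can be reproduced in plain Python. `closeness_report` is a hypothetical helper, not part of torch or torchvision, and it reports flat indices rather than the (row, column) indices in the real message.

```python
def closeness_report(actual, expected, atol=1e-3, rtol=1e-3):
    """Summarize how two equal-length sequences of floats differ.

    An element counts as close when |a - e| <= atol + rtol * |e|,
    the criterion documented for torch.allclose.
    """
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    abs_diffs = [abs(a - e) for a, e in zip(actual, expected)]
    # Relative difference is |a - e| / |e|; an exact match at e == 0 counts as 0.
    rel_diffs = [
        d / abs(e) if e != 0 else (0.0 if d == 0 else float("inf"))
        for d, e in zip(abs_diffs, expected)
    ]
    mismatched = sum(1 for d, e in zip(abs_diffs, expected) if d > atol + rtol * abs(e))
    max_abs_idx = max(range(len(abs_diffs)), key=abs_diffs.__getitem__)
    max_rel_idx = max(range(len(rel_diffs)), key=rel_diffs.__getitem__)
    return {
        "mismatched": mismatched,
        "total": len(expected),
        "max_abs": (abs_diffs[max_abs_idx], max_abs_idx),
        "max_rel": (rel_diffs[max_rel_idx], max_rel_idx),
    }


report = closeness_report([1.0, 2.5, 100.2], [1.0, 2.0, 100.0])
print(f"Mismatched elements: {report['mismatched']} / {report['total']}")  # Mismatched elements: 2 / 3
```

Note that the greatest absolute and greatest relative differences are reported at independent indices, which is why the message above points at (0, 29) and (0, 22) respectively.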

After tracing back, it seems the problem started around 8 or 9 December 2022. On 8 December 2022 the run succeeded, but it skipped the GPU tests and only ran the CPU tests (example of an 8 December 2022 run):

test/test_models.py::test_classification_model[cpu-wide_resnet50_2] PASSED [ 78%]
test/test_models.py::test_classification_model[cuda-alexnet] SKIPPED     [ 78%]

On 9 December 2022, the run executed both the CPU and GPU tests, and the GPU tests failed because their results differed from the CPU counterparts (example run on 9 December 2022; note that the resnet101 failure there has a different relative difference from the one on 17 January 2023):

FAILED test/test_models.py::test_classification_model[cuda-resnet152] - AssertionError: Tensor-likes are not close!

Mismatched elements: 9 / 50 (18.0%)
Greatest absolute difference: 10275.0 at index (0, 23) (up to 0.001 allowed)
Greatest relative difference: 0.042592364818058275 at index (0, 15) (up to 0.001 allowed)

Another observation: on 9 December 2022, in PR #6919, the GHA Linux GPU job failed due to the precision problem, while the CircleCI GPU test succeeded.

There was no change to the model (resnet34) or the test, and the CPU tests always succeeded between 8 December 2022 and 17 January 2023.

Possible problems

  • From [proto][ci] Try add GPU ci for prototype transforms #6919, it seems the precision error might be caused by moving to GHA (possibly a different GPU or configuration).
  • We also notice that the precision error changed between 9 December 2022 and 17 January 2023 and became larger, so another change may have made it worse. This might be due to changes in PyTorch core (we confirmed this by running the script below).
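For context on the possible causes above: results from different hardware or library configurations can differ at all because floating-point addition is not associative, so kernels that accumulate sums in a different order (which different CPU and GPU kernels may do) round differently. This is a minimal plain-Python illustration in float64, not specific to PyTorch and not a diagnosis of this particular regression; `naive_sum` is a hypothetical helper.

```python
import math


def naive_sum(xs):
    # Plain left-to-right accumulation: one rounding step per addition.
    total = 0.0
    for x in xs:
        total += x
    return total


# Regrouping three additions already changes the result.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

# Same numbers, different order: small terms are lost next to a huge one.
big_first = [1e16] + [1.0] * 10 + [-1e16]
small_first = [1.0] * 10 + [1e16, -1e16]
print(naive_sum(big_first))    # 0.0
print(naive_sum(small_first))  # 10.0
print(math.fsum(big_first))    # 10.0, correctly rounded reference
```

Deep networks chain many such accumulations, so small per-layer differences can compound into the larger output differences seen in the logs.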

Script to reproduce the problem

Here is a small script that reproduces the problem:

import torch
import torchvision
import random

def get_cpu_gpu_model_output_maxdiff(model_fn, seed):
    torch.manual_seed(seed)
    random.seed(seed)
    # Randomly initialize the model on CPU and copy the exact same weights to a GPU copy
    m_cpu = model_fn(num_classes=50).eval()
    m_gpu = model_fn(num_classes=50)
    m_gpu.load_state_dict(m_cpu.state_dict())
    m_gpu = m_gpu.to("cuda").eval()

    # Feed the same random input to both models
    input_shape = (1, 3, 224, 224)
    x_cpu = torch.rand(input_shape)
    x_gpu = x_cpu.clone().to("cuda")
    y_cpu = m_cpu(x_cpu).squeeze(0)
    y_gpu = m_gpu(x_gpu).to("cpu").squeeze(0)

    # Largest absolute difference, and the relative difference at that same index
    abs_diff = torch.abs(y_gpu - y_cpu)
    max_abs_diff = torch.max(abs_diff)
    max_abs_idx = torch.argmax(abs_diff)
    max_rel_diff = torch.abs(max_abs_diff / y_cpu[max_abs_idx])
    max_val_gpu = torch.max(torch.abs(y_gpu))
    mean_val_gpu = torch.mean(torch.abs(y_gpu))
    # Same tolerance as the model tests (atol = rtol = 1e-3)
    prec = 1e-3
    pass_test = torch.allclose(y_gpu, y_cpu, atol=prec, rtol=prec)
    print(f"  [{seed}]max_abs_diff: {max_abs_diff},\tmax_rel_diff: {max_rel_diff},\tmax_val_gpu: {max_val_gpu},\tmean_val_gpu: {mean_val_gpu},\tpass_test: {pass_test}")


for model_fn in [torchvision.models.resnet.resnet34, torchvision.models.resnet.resnet101, torchvision.models.efficientnet.efficientnet_b0]:
    print(f"model_fn: {model_fn.__name__}")
    for seed in range(5):
        get_cpu_gpu_model_output_maxdiff(model_fn, seed)

When I ran this script on an AWS cluster with CUDA 11.6 and Python 3.8 (the output of collect_env.py is included at the end of this section), I got the following output:

model_fn: resnet34
  [0]max_abs_diff: 0.012034416198730469,        max_rel_diff: 0.0012628681724891067,    max_val_gpu: 35.31159210205078, mean_val_gpu: 10.925275802612305,       pass_test: False
  [1]max_abs_diff: 0.01442718505859375, max_rel_diff: 0.0007660656701773405,    max_val_gpu: 37.38212585449219, mean_val_gpu: 12.549612998962402,       pass_test: False
  [2]max_abs_diff: 0.029125213623046875,        max_rel_diff: 0.001757694175466895,     max_val_gpu: 130.0574188232422, mean_val_gpu: 29.450868606567383,       pass_test: False
  [3]max_abs_diff: 0.014329195022583008,        max_rel_diff: 0.004036294762045145,     max_val_gpu: 38.02964401245117, mean_val_gpu: 12.15386962890625,        pass_test: False
  [4]max_abs_diff: 0.017838478088378906,        max_rel_diff: 0.001571571920067072,     max_val_gpu: 43.16202163696289, mean_val_gpu: 14.939404487609863,       pass_test: False
model_fn: resnet101
  [0]max_abs_diff: 9.53857421875,       max_rel_diff: 0.0014522294513881207,    max_val_gpu: 27715.107421875,   mean_val_gpu: 9278.046875,      pass_test: False
  [1]max_abs_diff: 30.28759765625,      max_rel_diff: 0.006908989977091551,     max_val_gpu: 47344.68359375,    mean_val_gpu: 12955.2421875,    pass_test: False
  [2]max_abs_diff: 19.2783203125,       max_rel_diff: 0.0016507417894899845,    max_val_gpu: 46184.32421875,    mean_val_gpu: 20209.998046875,  pass_test: False
  [3]max_abs_diff: 16.796875,   max_rel_diff: 0.0008035682258196175,    max_val_gpu: 32151.07421875,    mean_val_gpu: 13038.66015625,   pass_test: False
  [4]max_abs_diff: 17.41796875, max_rel_diff: 0.0015470916405320168,    max_val_gpu: 28275.6484375,     mean_val_gpu: 11823.5322265625, pass_test: False
model_fn: efficientnet_b0
  [0]max_abs_diff: 2.8128270112420806e-16,      max_rel_diff: 0.0031724595464766026,    max_val_gpu: 2.260063203208748e-13,     mean_val_gpu: 9.397712642800204e-14,    pass_test: True
  [1]max_abs_diff: 3.2067989756690007e-16,      max_rel_diff: 0.0029657522682100534,    max_val_gpu: 3.657568497204833e-13,     mean_val_gpu: 1.1394578585798704e-13,   pass_test: True
  [2]max_abs_diff: 1.7404748296886638e-16,      max_rel_diff: 0.013579211197793484,     max_val_gpu: 1.936249012686464e-13,     mean_val_gpu: 7.33500039136123e-14,     pass_test: True
  [3]max_abs_diff: 1.8396200361647796e-16,      max_rel_diff: 0.0022520553320646286,    max_val_gpu: 3.433499957874314e-13,     mean_val_gpu: 9.2998471337008e-14,      pass_test: True
  [4]max_abs_diff: 2.201540273867597e-16,       max_rel_diff: 0.002613184042274952,     max_val_gpu: 3.3239963516049076e-13,    mean_val_gpu: 1.0326688912624601e-13,   pass_test: True

Our tests use a tolerance of 0.001, and for the ResNet models these differences are consistently larger than the usual tolerance, which is unexpected.
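To see why the ResNet numbers fail at this tolerance, the logged values can be plugged into the closeness criterion |y_gpu - y_cpu| <= atol + rtol * |y_cpu| (the form `torch.allclose` documents). This worked check uses only the resnet101, seed 0 numbers from the log above; it relies on the script computing max_rel_diff as max_abs_diff divided by |y_cpu| at the same index.

```python
# Logged values for resnet101, seed 0 (random weights and input).
max_abs_diff = 9.53857421875
max_rel_diff = 0.0014522294513881207
atol = rtol = 1e-3

# max_rel_diff = max_abs_diff / |y_cpu| at the same index, so the reference
# magnitude at that index can be recovered from the two logged numbers.
ref_magnitude = max_abs_diff / max_rel_diff
allowed = atol + rtol * ref_magnitude

print(f"|y_cpu| ~ {ref_magnitude:.1f}, allowed ~ {allowed:.3f}, observed {max_abs_diff:.3f}")
print("close" if max_abs_diff <= allowed else "not close")  # not close
```

With outputs in the thousands, the rtol term dominates, so any element with a relative error noticeably above 0.1% fails regardless of atol.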

There seem to be no changes related to the ResNet models in torchvision, so most likely some change in PyTorch core is causing these differences.

The environment I used for this reproduction:

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.3
Libc version: glibc-2.27

Python version: 3.8.15 (default, Nov 24 2022, 15:19:38)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-1069-aws-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.991
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.4
[pip3] torch==2.0.0.dev20221221
[pip3] torchaudio==2.0.0.dev20221221
[pip3] torchvision==0.15.0a0+dca6617
[conda] blas                      1.0                         mkl
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py38h7f8727e_0
[conda] mkl_fft                   1.3.1            py38hd3c417c_0
[conda] mkl_random                1.2.2            py38h51133e4_0
[conda] numpy                     1.23.4           py38h14f4228_0
[conda] numpy-base                1.23.4           py38h31eccc5_0
[conda] pytorch                   2.0.0.dev20221221 py3.8_cuda11.6_cudnn8.3.2_0    pytorch-nightly
[conda] pytorch-cuda              11.6                 h867d48c_2    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.0.0.dev20221221      py38_cu116    pytorch-nightly
[conda] torchtriton               2.0.0+0d7e753227            py38    pytorch-nightly
[conda] torchvision               0.15.0a0+dca6617           dev_0    <develop>

cc @osalpekar @seemethere @atalman

Contributor Author

YosuaMichael commented Jan 17, 2023

Looking at the results of the script, the model tests seem to be very sensitive to the seed (the relative differences vary a lot between seeds). Also, for some models such as efficientnet, the output is close to zero when we use random input.
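The near-zero outputs matter because of how atol and rtol combine: when |y_cpu| is tiny, the atol term alone (1e-3) is many orders of magnitude larger than any difference, so the check passes even with a large relative error. A worked check using the efficientnet_b0, seed 2 numbers logged in the issue description:

```python
# Logged values for efficientnet_b0, seed 2 (random weights and input).
max_abs_diff = 1.7404748296886638e-16
max_rel_diff = 0.013579211197793484
atol = rtol = 1e-3

ref_magnitude = max_abs_diff / max_rel_diff  # |y_cpu| at that index, on the order of 1e-14
allowed = atol + rtol * ref_magnitude        # essentially just atol

print(max_rel_diff > rtol)      # True: the relative error is over 13x rtol
print(max_abs_diff <= allowed)  # True: yet the element passes because of atol
```

In other words, with random inputs the efficientnet test is effectively not checking anything, while the ResNet test is effectively checking only the relative error.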

To address these two problems, I ran an experiment with real pretrained weights and a real image. Here is the script:

import torch
import torchvision
import random

from PIL import Image

img_path = "grace_hopper_517x606.jpg"
img_pil = Image.open(img_path)

def get_cpu_gpu_model_output_maxdiff(model_fn, seed):
    torch.manual_seed(seed)
    random.seed(seed)
    # Use real pretrained weights (the DEFAULT weights for each model)
    weight_enum = torchvision.models.get_model_weights(model_fn)
    weight = weight_enum.DEFAULT

    preprocess = weight.transforms()
    x_cpu = preprocess(img_pil).unsqueeze(0).to("cpu")
    x_gpu = preprocess(img_pil).unsqueeze(0).to("cuda")

    m_cpu = model_fn(weights=weight).eval()
    m_gpu = model_fn(weights=weight).cuda().eval()

    y_cpu = m_cpu(x_cpu).squeeze(0)
    y_gpu = m_gpu(x_gpu).to("cpu").squeeze(0)

    abs_diff = torch.abs(y_gpu - y_cpu)
    max_abs_diff = torch.max(abs_diff)
    max_abs_idx = torch.argmax(abs_diff)
    max_rel_diff = torch.abs(max_abs_diff / y_cpu[max_abs_idx])
    max_val_gpu = torch.max(torch.abs(y_gpu))
    mean_val_gpu = torch.mean(torch.abs(y_gpu))
    prec = 1e-3
    pass_test = torch.allclose(y_gpu, y_cpu, atol=prec, rtol=prec)
    print(f"  [{seed}]max_abs_diff: {max_abs_diff},\tmax_rel_diff: {max_rel_diff},\tmax_val_gpu: {max_val_gpu},\tmean_val_gpu: {mean_val_gpu},\tpass_test: {pass_test}")


for model_fn in [torchvision.models.resnet.resnet34, torchvision.models.resnet.resnet101, torchvision.models.efficientnet.efficientnet_b0]:
    print(f"model_fn: {model_fn.__name__}")
    for seed in range(1):
        get_cpu_gpu_model_output_maxdiff(model_fn, seed)

I used the following image for the test: https://github.com/pytorch/vision/blob/main/test/assets/encode_jpeg/grace_hopper_517x606.jpg

Since we use a real image and real weights, there is no randomness, which is why I only use one seed (I tried multiple seeds and confirmed that they produce exactly the same results).

Here is the output using the real weights and image:

model_fn: resnet34
  [0]max_abs_diff: 0.004706382751464844,        max_rel_diff: 0.0009836413664743304,    max_val_gpu: 11.03113079071045, mean_val_gpu: 1.7828288078308105,       pass_test: False
model_fn: resnet101
  [0]max_abs_diff: 0.006227970123291016,        max_rel_diff: 0.0012130645336583257,    max_val_gpu: 7.295019626617432, mean_val_gpu: 0.4557604491710663,       pass_test: False
model_fn: efficientnet_b0
  [0]max_abs_diff: 0.006254449486732483,        max_rel_diff: 0.03236928954720497,      max_val_gpu: 9.316937446594238, mean_val_gpu: 0.8847216963768005,       pass_test: False

Compared to the random image and random weights, the absolute differences are now more comparable between efficientnet and resnet.

Note: max_rel_diff is the relative difference at the index with the largest absolute difference (max_abs_diff), not the largest relative difference over all indices.
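Because of this definition, max_rel_diff can sit below 1e-3 while the tensor as a whole still fails. For example, resnet34 with seed 1 logs max_rel_diff of about 0.00077 together with pass_test: False. For that element |y_cpu| is about 0.0144 / 0.00077, roughly 18.8, so the allowed difference is about 0.001 + 0.0188, roughly 0.0198, which the observed 0.0144 is within; the failing element must be a different one. A small sketch contrasting the two metrics (both helper names are hypothetical):

```python
def rel_diff_at_max_abs(actual, expected):
    # The script's metric: relative difference at the index of the largest
    # absolute difference.
    abs_diffs = [abs(a - e) for a, e in zip(actual, expected)]
    idx = max(range(len(abs_diffs)), key=abs_diffs.__getitem__)
    return abs_diffs[idx] / abs(expected[idx])


def worst_rel_diff(actual, expected):
    # Largest relative difference over all elements (assumes no zeros in expected).
    return max(abs(a - e) / abs(e) for a, e in zip(actual, expected))


expected = [100.0, 0.5]
actual = [100.04, 0.51]
print(rel_diff_at_max_abs(actual, expected))  # about 0.0004, below 1e-3
print(worst_rel_diff(actual, expected))       # about 0.02, from the small element
```

With atol = rtol = 1e-3, the second element is allowed only 0.0015 but differs by 0.01, so an allclose-style check fails even though the first metric looks fine.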

Although the results are more consistent across models, none of the models pass the test with a precision of 0.001. This suggests that there is currently a significant difference between the CPU and GPU model outputs (from these results, the absolute difference can be up to 0.0063).
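As a rough sense of scale only (assuming a purely absolute tolerance, rtol = 0, and just these three models on this single image), the logged numbers give the smallest atol under which all three would pass. This is arithmetic on the log, not a recommended tolerance.

```python
# max_abs_diff values logged above with real weights and the Grace Hopper image.
max_abs_diffs = {
    "resnet34": 0.004706382751464844,
    "resnet101": 0.006227970123291016,
    "efficientnet_b0": 0.006254449486732483,
}

# With rtol = 0, allclose passes iff every |diff| <= atol, so the smallest
# passing atol is simply the largest observed absolute difference.
worst_model = max(max_abs_diffs, key=max_abs_diffs.get)
print(f"{worst_model}: {max_abs_diffs[worst_model]:.4f}")  # efficientnet_b0: 0.0063
```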

Are these differences between CPU and GPU expected? If so, I think our TorchVision tests should use real weights and images, and then relax the precision constraint for GPU. Otherwise, we should investigate which factor contributes most to the differences. cc @NicolasHug

Contributor Author

I compiled statistics for all classification models in torchvision: https://docs.google.com/spreadsheets/d/162nq0p0-7Be0nBzffyEMhMF5PohNQ6Ew1OPdCcCVsGY/edit#gid=790607453
