
[Bug] Outputs of torch.flatten abnormally mismatch on GPU when adding an intermediate result as output #380

Closed
Azyka opened this issue Nov 20, 2023 · 1 comment
Labels: bug (Something isn't working)

Comments


Azyka commented Nov 20, 2023

Describe the bug
When the intermediate result abs_1, which feeds the original output, is added as an extra output of this model:

class Model0(torch.nn.Module):
    def forward(self, *args):
        abs_1 = torch.abs(args[0])
        flatten = abs_1.flatten()
        return (flatten)

New:

class Model1(torch.nn.Module):
    def forward(self, *args):
        abs_1 = torch.abs(args[0])
        flatten = abs_1.flatten()
        return (abs_1, flatten)

The output of torch.flatten is expected to be the same for the same input. However, the flatten outputs mismatch between the two models when they are compiled with the hidet backend, and the mismatch is seen only on CUDA.
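A shorter way to see the same behavior (a sketch, assuming CUDA and the hidet backend are available; the full repro script is below):

import torch

class Model1(torch.nn.Module):
    def forward(self, x):
        abs_1 = torch.abs(x)
        return abs_1, abs_1.flatten()

x = torch.randint(3, 8, (51,), dtype=torch.int8, device='cuda')
eager_out = Model1()(x)
hidet_out = torch.compile(Model1(), fullgraph=True, backend='hidet')(x)

# The flattened output should match the eager run; in the reported run it
# came back as all zeros under hidet.
print(torch.equal(eager_out[1].cpu(), hidet_out[1].cpu()))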

To Reproduce
Repro script:

import numpy as np
from numpy import testing
import torch

DEVICE = 'cuda'

class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        abs_1 = torch.abs(args[0])
        flatten = abs_1.flatten()
        return (flatten)

model_0 = Model0()
output_names_0 = ['v0_0']

class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        abs_1 = torch.abs(args[0])
        flatten = abs_1.flatten()
        return (abs_1, flatten)

model_1 = Model1()
output_names_1 = ['v5_0', 'v0_0']

data = np.array([6, 3, 4, 5, 4, 7, 5, 5, 3, 3, 4, 4, 6, 3, 7, 5, 4, 3, 5, 6, 3, 7,
       7, 5, 6, 6, 5, 4, 5, 6, 5, 3, 3, 5, 4, 5, 3, 7, 6, 6, 6, 4, 5, 3,
       7, 4, 4, 6, 5, 3, 7], dtype=np.int8)
input_data_0 = [data]

# Compile the original model with the hidet backend and collect its output.
optmodel_0 = torch.compile(model_0, fullgraph=True, backend='hidet', mode=None)
model_out_0 = optmodel_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

input_data_1 = [data]

# Compile the model with the extra output and collect both outputs.
optmodel_1 = torch.compile(model_1, fullgraph=True, backend='hidet', mode=None)
model_out_1 = optmodel_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))
output_name_dict = {'v0_0': 'v0_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("hidet does not trigger assertion")
except AssertionError as e:
    print("hidet triggers assertion")
    print(e)
print('=========================')

# Reference: repeat the comparison in eager mode (no compilation).
model_out_0 = model_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

model_out_1 = model_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("torch_eager does not trigger assertion")
except AssertionError as e:
    print("torch_eager triggers assertion")
    print(e)
print('=========================')

Output:

=========================
hidet triggers assertion

Not equal to tolerance rtol=1, atol=0
at v0_0, v0_0
Mismatched elements: 51 / 51 (100%)
Max absolute difference: 7
Max relative difference: inf
 x: array([6, 3, 4, 5, 4, 7, 5, 5, 3, 3, 4, 4, 6, 3, 7, 5, 4, 3, 5, 6, 3, 7,
       7, 5, 6, 6, 5, 4, 5, 6, 5, 3, 3, 5, 4, 5, 3, 7, 6, 6, 6, 4, 5, 3,
       7, 4, 4, 6, 5, 3, 7], dtype=int8)
 y: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0], dtype=int8)
=========================
=========================
torch_eager does not trigger assertion
=========================

Expected behavior
The output of torch.flatten is expected to be identical for the same input, regardless of whether the intermediate result abs_1 is also returned as an output.
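Concretely, for any input the following invariant should hold, compiled or eager (a sketch reusing the Model0 and Model1 classes from the repro script above):

import torch

x = torch.randint(3, 8, (51,), dtype=torch.int8)
m0, m1 = Model0(), Model1()

# Model1 returns (abs_1, flatten); its flatten output should equal
# Model0's single flatten output for the same input.
assert torch.equal(m0(x), m1(x)[1])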

Environment

  • OS: Ubuntu 22.04.3 LTS (x86_64)
  • GPU: RTX 1660
  • NVIDIA GPU Driver: 525.147.05
  • Hidet Version: 0.3.0
  • PyTorch Version: 2.1.0+cu118
Azyka added the bug label on Nov 20, 2023

Azyka commented Dec 6, 2023

Fixed in #384. Thanks for your efforts on it, @Aalanli and @yaoyaoding!

Azyka closed this as completed on Dec 6, 2023
vadiklyutiy added commits that referenced this issue on Dec 19, Dec 20, and Dec 26, 2024:

Save `Task` pickle in the translations cache.

The reason: during performance analysis it is very convenient to get a smaller test case. Supporting scripts will come soon.