[Core Aten ops] Logs for fixing core aten ops coverage issues #5934

Closed
qihqi opened this issue Nov 29, 2023 · 3 comments

qihqi commented Nov 29, 2023

Let's use this issue as a space for sharing notes and steps for adding lowerings for missing core aten ops.


qihqi commented Nov 29, 2023

Issue being worked on: #5902

1. Uncomment and rerun the test:

LD_LIBRARY_PATH=/mnt/hanq/miniconda3/envs/torch310/lib/:/usr/lib/x86_64-linux-gnu/ PJRT_DEVICE=CPU XLA_STABLEHLO_COMPILE=1 XLA_HLO_DEBUG=1 XLA_IR_DEBUG=1 pytest test/test_core_aten_ops.py -k test_aten_tan_1

output:

=========================== short test summary info ============================
[torch_xla_diff:0.001] SUBFAIL test/test_core_aten_ops.py::AtenOpTest::test_aten_tan_1 - AssertionError: False is not true
[stablehlo_diff: 0.001] SUBFAIL test/test_core_aten_ops.py::AtenOpTest::test_aten_tan_1 - AssertionError: False is not true
================= 2 failed, 1 passed, 514 deselected in 5.51s ==================
I0000 00:00:1700690393.569658 2513762 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(torch310) hanq@hanq-compile-2:/mnt/hanq/git/qihqi/pytorch/xla$

This means the op runs, but the accuracy does not meet the default tolerance.
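The same mismatch can also be reproduced in a few lines outside pytest, which makes iteration quicker. This is only a sketch: it assumes PJRT_DEVICE=CPU is exported and uses a made-up 10x10 float16 input to mirror the test above.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(10, 10, dtype=torch.float16)

res = torch.ops.aten.tan(x)                 # eager CPU reference
res_xla = torch.ops.aten.tan(x.to(device))  # same op lowered through torch_xla

# Largest absolute deviation between the two results
print(torch.max(torch.abs(res - res_xla.cpu())))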

2. Set a breakpoint at the comparison:

(torch310) hanq@hanq-compile-2:/mnt/hanq/git/qihqi/pytorch/xla$ git diff
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 46a18494d..ff055ee38 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -36,6 +36,7 @@ def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
                                      lambda x: x.to(device=device), kwargs)
       res_xla = func(*args2, **kwargs2)
       with testcase.subTest('torch_xla_diff:' + str(atol)):
+        import pdb; pdb.set_trace()
         diff_output(testcase, res, res_xla, atol)

3. Rerun and print out the difference:

(Pdb) p res - res_xla.cpu()
tensor([[ 0.0000e+00,  0.0000e+00, -4.8828e-04,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  6.1035e-05,  0.0000e+00,  0.0000e+00],
        [-4.8828e-04,  0.0000e+00,  0.0000e+00,  9.7656e-04,  0.0000e+00,
          1.2207e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -1.5259e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          4.8828e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2207e-04],
        [ 0.0000e+00,  2.4414e-04,  0.0000e+00, -1.9531e-03,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -3.0518e-05,  0.0000e+00],
        [ 0.0000e+00, -4.8828e-04, -2.4414e-04,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -6.1035e-05,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.9531e-03,
          0.0000e+00,  0.0000e+00,  1.9531e-03,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -1.9531e-03,  0.0000e+00,  0.0000e+00,
          2.4414e-04,  9.7656e-04,  1.2207e-04,  0.0000e+00,  0.0000e+00],
        [ 4.8828e-04,  0.0000e+00,  0.0000e+00, -7.8125e-03,  1.2207e-04,
         -9.7656e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  1.5625e-02,  0.0000e+00,  0.0000e+00, -4.8828e-04,
         -1.2207e-04,  0.0000e+00,  0.0000e+00, -4.8828e-04, -3.9062e-03],
        [ 0.0000e+00, -1.2207e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
       dtype=torch.float16)

The result looks close enough. This means we are probably being too strict in the test; setting a larger tolerance will likely make it pass.

(Pdb) p torch.max(torch.abs(res - res_xla.cpu()))
tensor(0.0156, dtype=torch.float16)

Printing out the maximum absolute difference shows that an atol of roughly 0.01, together with a slightly larger rtol, will probably work:

(Pdb) torch.allclose(res, res_xla.cpu(), atol=0.01, rtol=0.001)
True
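Given that, the fix is simply to relax the tolerance for this test case. A minimal sketch of what the updated test could look like, assuming the test body follows the usual pattern in test_core_aten_ops.py and that run_export_and_compare keeps the atol keyword shown in the diff above (if an rtol knob is wanted, the harness would need to grow one; the value here comes from the measurement above):

def test_aten_tan_1(self):
  args = (torch.randn((10, 10)).to(torch.float16),)
  kwargs = dict()
  # float16 tan accumulates up to ~1.6e-2 absolute error on this input,
  # so the default atol=1e-3 is too strict for this op.
  run_export_and_compare(self, torch.ops.aten.tan, args, kwargs, atol=0.01)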

4. Now it's time to send a PR:
#5915

wonjoolee95 commented

Working on issue: #5934


Doing a quick check of the differences (res - res_xla), we can see that the results are pretty much equal:

WONJOO: at diff_output, output1-output2_cpu=tensor([[        nan,         nan,         nan,         nan,  0.0000e+00,
          0.0000e+00,         nan,         nan,  0.0000e+00,         nan],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
                 nan,         nan,         nan,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,         nan,         nan,  0.0000e+00,  0.0000e+00,
                 nan,         nan,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,         nan,         nan,  0.0000e+00,         nan,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [        nan, -1.4901e-08,         nan,  0.0000e+00,  0.0000e+00,
          0.0000e+00,         nan,         nan,         nan,  0.0000e+00],
        [-3.7253e-09,  0.0000e+00,         nan,         nan,  0.0000e+00,
                 nan,         nan,  0.0000e+00,  2.9802e-08,  0.0000e+00],
        [        nan,         nan,  0.0000e+00,         nan,         nan,
          0.0000e+00,  0.0000e+00,  0.0000e+00,         nan,  0.0000e+00],
        [        nan,  1.1921e-07,         nan,         nan,         nan,
                 nan,         nan,  0.0000e+00,         nan,  0.0000e+00],
        [        nan,  0.0000e+00,         nan,  0.0000e+00,         nan,
                 nan,         nan,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,         nan,  0.0000e+00,
                 nan,  0.0000e+00,         nan,         nan,  0.0000e+00]])

But we see a bunch of NaNs, which is expected for the aten_log op, as ln(x) is undefined for x <= 0. Looking at torch.allclose's documentation (https://pytorch.org/docs/stable/generated/torch.allclose.html), there is a flag called equal_nan that defaults to False. If set to True, this flag considers two NaNs as equal, which is what we want at least for this aten_log op.

Note that equal_nan should stay False by default; only for specific ops such as aten_log do we want to set it to True.
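For reference, this is how equal_nan changes the comparison (a tiny illustration with made-up values, not the test's actual tensors):

import torch

a = torch.tensor([float('nan'), 1.0])
b = torch.tensor([float('nan'), 1.0])

print(torch.allclose(a, b))                  # False: NaN never equals NaN by default
print(torch.allclose(a, b, equal_nan=True))  # True: NaNs in matching positions count as equal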

wonjoolee95 commented

Closing as all issues under this label have been resolved.
