Results do not match the reference. This is likely a bug/unexpected loss of precision #24909

Open
yanboyang97 opened this issue Nov 15, 2024 · 0 comments
Labels
bug Something isn't working


Description

import jax.numpy as jnp
import jax
import flax.linen as nn
from collections.abc import Callable
import time

class rbfnet(nn.Module):

    pointnums: int = 100
    areafun: Callable = nn.silu
    init_value: jax.Array = jnp.ones((100, 3))

    def setup(self):
        self.position = self.param('position', lambda rng, init_value: init_value, self.init_value) # (pointnums, dim)
        self.weight = self.param('weight', nn.initializers.zeros, (self.pointnums,)) # (pointnums)
        
    def __call__(self, x: jax.Array):
        batch = x.shape[0]
        x = jnp.expand_dims(x, axis=1).repeat(self.pointnums, axis=1) # (batch, pointnums, dim)
        # print(x.shape)
        position = jnp.expand_dims(self.position, axis=0).repeat(batch, axis=0) # (batch, pointnums, dim)
        position = self.areafun(position)
        # print(position.shape)
        distance = jnp.linalg.norm(x - position, ord=2, axis=-1) # (batch, pointnums)
        output = (1 / distance) @ self.weight
        return output

def main():
    x = jnp.linspace(2, 3, 50)
    y = jnp.linspace(2, 3, 50)
    z = jnp.linspace(2, 3, 50)
    X, Y, Z = jnp.meshgrid(x, y, z, indexing='ij')
    data = jnp.stack([X.reshape(-1), Y.reshape(-1), Z.reshape(-1)], axis=-1)
    # print(data.shape)

    model = rbfnet()
    variables = model.init(jax.random.key(0), data)
    # print(variables)

    @jax.jit
    def forward_and_backward(variables, x):
        # Compute the forward pass
        def loss_fn(variables, x):
            return jnp.mean(model.apply(variables, x))
        loss = loss_fn(variables, x)
        # Compute gradients
        grads = jax.grad(loss_fn)(variables, x)
        return loss, grads

    loss, grads = forward_and_backward(variables, data)

if __name__ == '__main__':
    main()
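
Since main() discards loss and grads, the reproduction does not show whether the returned values themselves are finite. A minimal sketch of a check that could be appended, assuming only jax.tree_util.tree_leaves and jnp.isfinite; the helper name all_finite and the stand-in pytree are illustrative and not part of the original script:

```python
import jax
import jax.numpy as jnp


def all_finite(tree) -> bool:
    """Return True if every array leaf of a pytree holds only finite values."""
    leaves = jax.tree_util.tree_leaves(tree)
    return all(bool(jnp.all(jnp.isfinite(leaf))) for leaf in leaves)


# Stand-in pytree with the same shapes as the model's parameters
# (hypothetical values, used here only to exercise the helper).
example = {"params": {"position": jnp.ones((100, 3)), "weight": jnp.zeros((100,))}}
print(all_finite(example))
```

In main(), the same helper could be called as all_finite((loss, grads)) right after forward_and_backward returns.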

Running the code above produces these errors:

2024-11-15 10:47:13.733095: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E1115 10:47:13.735496    3712 buffer_comparator.cc:157] Difference at 0: inf, expected 2.00886
E1115 10:47:13.735510    3712 buffer_comparator.cc:157] Difference at 1: inf, expected 2.01074
E1115 10:47:13.735515    3712 buffer_comparator.cc:157] Difference at 2: inf, expected 2.01014
E1115 10:47:13.735519    3712 buffer_comparator.cc:157] Difference at 3: inf, expected 2.00902
E1115 10:47:13.735523    3712 buffer_comparator.cc:157] Difference at 4: inf, expected 2.01255
E1115 10:47:13.735539    3712 buffer_comparator.cc:157] Difference at 5: inf, expected 2.00876
E1115 10:47:13.735543    3712 buffer_comparator.cc:157] Difference at 6: inf, expected 2.01238
E1115 10:47:13.735546    3712 buffer_comparator.cc:157] Difference at 7: inf, expected 2.00943
E1115 10:47:13.735550    3712 buffer_comparator.cc:157] Difference at 8: inf, expected 2.01083
E1115 10:47:13.735554    3712 buffer_comparator.cc:157] Difference at 9: inf, expected 2.01033

There are too many errors to display in full, but the rest follow the same pattern.

However, when I reduced the sizes of x, y, and z, for example:
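
For context, a rough bound check suggests that the script's own inputs should not produce inf in 1 / distance at the initial parameters. Every position coordinate is silu(1), and every data point lies in the cube [2, 3]^3, so the distance is bounded away from zero. This is a sketch using only the values in the script; it says nothing about which buffers the autotuner compares, which the log does not show:

```python
import math


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


# init_value is all ones, so every coordinate of every position is silu(1).
center = silu(1.0)

# The data grid spans [2, 3] on each axis; the nearest corner to the
# position is (2, 2, 2) and the farthest is (3, 3, 3).
lo, hi = 2.0, 3.0
d_min = math.sqrt(3) * (lo - center)
d_max = math.sqrt(3) * (hi - center)

print(f"silu(1)    = {center:.4f}")
print(f"distance   in [{d_min:.4f}, {d_max:.4f}]")
print(f"1/distance in [{1 / d_max:.4f}, {1 / d_min:.4f}]")
```

With these inputs 1 / distance stays between roughly 0.25 and 0.46, so an inf would have to come from somewhere other than a division by zero on the real data.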

    x = jnp.linspace(2, 3, 5)
    y = jnp.linspace(2, 3, 5)
    z = jnp.linspace(2, 3, 5)

the errors disappeared.
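
Because the message comes from gemm_fusion_autotuner.cc, one diagnostic that may help narrow this down is to rerun the 50-point case with autotuning turned down or with Triton GEMM fusion disabled through XLA_FLAGS. The sketch below assumes the XLA debug options --xla_gpu_autotune_level and --xla_gpu_enable_triton_gemm; whether either one actually silences this message on jaxlib 0.4.34 is untested here. The helper add_xla_flag is a hypothetical name:

```python
import os


def add_xla_flag(flag: str, env=os.environ) -> str:
    """Append `flag` to XLA_FLAGS in `env` unless it is already present.

    Returns the resulting XLA_FLAGS string. Existing flags are preserved.
    """
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# Set the flag before JAX initializes its backend; doing it before the
# first `import jax` is the safe option.
add_xla_flag("--xla_gpu_autotune_level=0")
# Alternative diagnostic (assumption, try one at a time):
# add_xla_flag("--xla_gpu_enable_triton_gemm=false")

print(os.environ["XLA_FLAGS"])
```

If the errors disappear with one of these flags while the results stay the same, that would point at a specific autotuned kernel configuration rather than at the model code.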

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='eda', release='3.10.0-1160.el7.x86_64', version='#1 SMP Mon Oct 19 16:18:59 UTC 2020', machine='x86_64')


$ nvidia-smi
Fri Nov 15 11:14:06 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0 Off |                  Off |
|  0%   42C    P2              34W / 450W |  18657MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      6628      G   /usr/bin/X                                   42MiB |
|    0   N/A  N/A     12953      C   ...y/anaconda3/envs/jax_kan/bin/python    18554MiB |
|    0   N/A  N/A     79518      G   /usr/bin/gnome-shell                         38MiB |
+---------------------------------------------------------------------------------------+
@yanboyang97 yanboyang97 added the bug Something isn't working label Nov 15, 2024