We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
import time  # NOTE(review): unused by this script; kept in case other code relies on it.
from collections.abc import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp


class rbfnet(nn.Module):
    """Weighted sum of inverse Euclidean distances to learnable centres.

    For each query row ``x_b`` the output is
    ``sum_j weight_j / ||x_b - areafun(position_j)||_2``.

    Attributes:
        pointnums: number of centres, i.e. rows of ``position`` and length of
            ``weight``.
        areafun: element-wise function applied to the raw ``position``
            parameter before distances are computed.
        init_value: initial value for the ``position`` parameter, expected
            shape ``(pointnums, dim)``. When ``None`` (the default) it is
            ``jnp.ones((pointnums, 3))``. For the default ``pointnums=100``
            this equals the previous hard-coded default, and it now stays
            consistent when ``pointnums`` is changed.
    """

    pointnums: int = 100
    areafun: Callable = nn.silu
    init_value: jax.Array | None = None

    def setup(self):
        init_value = self.init_value
        if init_value is None:
            # Derived from pointnums so position and weight always agree in size.
            init_value = jnp.ones((self.pointnums, 3))
        elif jnp.shape(init_value)[:1] != (self.pointnums,):
            raise ValueError(
                f"init_value must have {self.pointnums} rows (pointnums), "
                f"got shape {jnp.shape(init_value)}"
            )
        # The initializer ignores the RNG and returns the given value unchanged.
        self.position = self.param('position', lambda rng, value: value, init_value)  # (pointnums, dim)
        self.weight = self.param('weight', nn.initializers.zeros, (self.pointnums,))  # (pointnums,)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Evaluate the network on ``x``.

        Args:
            x: query points; indexed below as ``(batch, dim)``.

        Returns:
            Array of shape ``(batch,)``.
        """
        # areafun is element-wise, so applying it once to the (pointnums, dim)
        # parameter gives the same values as applying it to a per-batch copy.
        centres = self.areafun(self.position)
        # Broadcast (batch, 1, dim) against (1, pointnums, dim) rather than
        # materialising repeated copies of both operands with .repeat().
        diff = x[:, None, :] - centres[None, :, :]
        distance = jnp.linalg.norm(diff, ord=2, axis=-1)  # (batch, pointnums)
        # NOTE(review): a query point that coincides with a centre gives a zero
        # distance and an inf term here; add an epsilon if that can happen.
        return (1 / distance) @ self.weight


def main(num_points_per_axis: int = 50):
    """Run one jitted forward + backward pass of ``rbfnet`` on a grid over [2, 3]^3.

    Args:
        num_points_per_axis: grid resolution per axis. The default 50 gives
            50**3 = 125000 query points.

    Returns:
        ``(loss, grads)``: the mean network output and its gradient with
        respect to the model variables.

    NOTE(review): the "Results do not match the reference" messages reported
    for this script come from XLA's GPU ``gemm_fusion_autotuner``, not from
    this code. Commonly suggested workarounds are upgrading jax/jaxlib, or
    running with ``XLA_FLAGS=--xla_gpu_enable_triton_gemm=false`` or
    ``--xla_gpu_autotune_level=0``. Confirm against the XLA/JAX issue tracker.
    """
    axis = jnp.linspace(2, 3, num_points_per_axis)
    X, Y, Z = jnp.meshgrid(axis, axis, axis, indexing='ij')
    data = jnp.stack([X.reshape(-1), Y.reshape(-1), Z.reshape(-1)], axis=-1)  # (n**3, 3)

    model = rbfnet()
    variables = model.init(jax.random.key(0), data)

    def loss_fn(variables, x):
        return jnp.mean(model.apply(variables, x))

    # value_and_grad yields the loss and its gradient from one traced function,
    # instead of evaluating loss_fn and jax.grad(loss_fn) separately.
    forward_and_backward = jax.jit(jax.value_and_grad(loss_fn))
    loss, grads = forward_and_backward(variables, data)
    return loss, grads


if __name__ == '__main__':
    main()
I get these errors after running the code above:
2024-11-15 10:47:13.733095: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision. E1115 10:47:13.735496 3712 buffer_comparator.cc:157] Difference at 0: inf, expected 2.00886 E1115 10:47:13.735510 3712 buffer_comparator.cc:157] Difference at 1: inf, expected 2.01074 E1115 10:47:13.735515 3712 buffer_comparator.cc:157] Difference at 2: inf, expected 2.01014 E1115 10:47:13.735519 3712 buffer_comparator.cc:157] Difference at 3: inf, expected 2.00902 E1115 10:47:13.735523 3712 buffer_comparator.cc:157] Difference at 4: inf, expected 2.01255 E1115 10:47:13.735539 3712 buffer_comparator.cc:157] Difference at 5: inf, expected 2.00876 E1115 10:47:13.735543 3712 buffer_comparator.cc:157] Difference at 6: inf, expected 2.01238 E1115 10:47:13.735546 3712 buffer_comparator.cc:157] Difference at 7: inf, expected 2.00943 E1115 10:47:13.735550 3712 buffer_comparator.cc:157] Difference at 8: inf, expected 2.01083 E1115 10:47:13.735554 3712 buffer_comparator.cc:157] Difference at 9: inf, expected 2.01033
There are too many errors to display them all, but the others are the same as these.
However, when I made x, y, and z smaller, like this:
x = jnp.linspace(2, 3, 5) y = jnp.linspace(2, 3, 5) z = jnp.linspace(2, 3, 5)
the errors disappeared.
jax: 0.4.35 jaxlib: 0.4.34 numpy: 2.1.3 python: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] device info: NVIDIA GeForce RTX 4090-1, 1 local devices" process_count: 1 platform: uname_result(system='Linux', node='eda', release='3.10.0-1160.el7.x86_64', version='#1 SMP Mon Oct 19 16:18:59 UTC 2020', machine='x86_64') $ nvidia-smi Fri Nov 15 11:14:06 2024 +---------------------------------------------------------------------------------------+ | NVIDIA-SMI 535.146.02 Driver Version: 535.146.02 CUDA Version: 12.2 | |-----------------------------------------+----------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |=========================================+======================+======================| | 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off | | 0% 42C P2 34W / 450W | 18657MiB / 24564MiB | 0% Default | | | | N/A | +-----------------------------------------+----------------------+----------------------+ +---------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=======================================================================================| | 0 N/A N/A 6628 G /usr/bin/X 42MiB | | 0 N/A N/A 12953 C ...y/anaconda3/envs/jax_kan/bin/python 18554MiB | | 0 N/A N/A 79518 G /usr/bin/gnome-shell 38MiB | +---------------------------------------------------------------------------------------+
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Description
I get these errors after running the code above:
There are too many errors to display them all, but the others are the same as these.
However, when I made x, y, and z smaller, like this:
the errors disappeared.
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: