From 56d8ca3e7c96f1a65dde89f65185501c94f2a905 Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Fri, 6 Sep 2024 20:36:39 +0200 Subject: [PATCH 1/2] Update jax benchmark file --- benchmarks/jax/benchmark.py | 43 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/benchmarks/jax/benchmark.py b/benchmarks/jax/benchmark.py index ec2a770f9..f17d70db1 100644 --- a/benchmarks/jax/benchmark.py +++ b/benchmarks/jax/benchmark.py @@ -37,9 +37,12 @@ def sphericart_benchmark( key = jax.random.PRNGKey(0) xyz = jax.random.normal(key, (n_samples, 3), dtype=dtype) - sh_calculator = sphericart.jax.spherical_harmonics + if normalized: + sh_calculator = sphericart.jax.spherical_harmonics + else: + sh_calculator = sphericart.jax.solid_harmonics - sh_calculator_jit = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2)) + sh_calculator_jit = jax.jit(sh_calculator, static_argnums=(1,)) print( f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, " @@ -52,7 +55,7 @@ def sphericart_benchmark( for i in range(n_tries + warmup): elapsed = -time.time() - sh_sphericart = sh_calculator(xyz, l_max, normalized) + sh_sphericart = sh_calculator(xyz, l_max) elapsed += time.time() time_noderi[i] = elapsed @@ -69,7 +72,7 @@ def sphericart_benchmark( for i in range(n_tries + warmup): elapsed = -time.time() - sh_sphericart_jit = sh_calculator_jit(xyz, l_max, normalized) + sh_sphericart_jit = sh_calculator_jit(xyz, l_max) elapsed += time.time() time_noderi[i] = elapsed @@ -82,15 +85,15 @@ def sphericart_benchmark( if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup]) - def scalar_output(xyz, l_max, normalized): - return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized)) + def scalar_output(xyz, l_max): + return jax.numpy.sum(sh_calculator(xyz, l_max)) - sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=(1, 2)) + sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=(1,)) time_deri = np.zeros(n_tries + warmup) for i in range(n_tries + warmup): elapsed = -time.time() - sh_sphericart_grad_jit = sh_grad(xyz, l_max, normalized) + sh_sphericart_grad_jit = sh_grad(xyz, l_max) elapsed += time.time() time_deri[i] = elapsed @@ -103,21 +106,19 @@ def scalar_output(xyz, l_max, normalized): if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) - def single_scalar_output(x, l_max, normalized): - return jax.numpy.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized)) + def single_scalar_output(x, l_max): + return jax.numpy.sum(sh_calculator(x, l_max)) # Compute the Hessian for a single (3,) input single_hessian = jax.hessian(single_scalar_output) # Use vmap to vectorize the Hessian computation over the first axis - sh_hess = jax.jit( - jax.vmap(single_hessian, in_axes=(0, None, None)), static_argnums=(1, 2) - ) + sh_hess = jax.jit(jax.vmap(single_hessian, in_axes=(0, None)), static_argnums=(1,)) time_deri = np.zeros(n_tries + warmup) for i in range(n_tries + warmup): elapsed = -time.time() - sh_sphericart_hess_jit = sh_hess(xyz, l_max, normalized) + sh_sphericart_hess_jit = sh_hess(xyz, l_max) elapsed += time.time() time_deri[i] = elapsed @@ -133,17 +134,15 @@ def single_scalar_output(x, l_max, normalized): # calculate a function of the spherical harmonics that returns an array # and take its jacobian with respect to the input Cartesian coordinates, # both in forward mode and in reverse mode - def array_output(xyz, l_max, normalized): - return jax.numpy.sum( - sphericart.jax.spherical_harmonics(xyz, l_max, normalized), axis=0 - ) + def array_output(xyz, l_max): + return jax.numpy.sum(sh_calculator(xyz, l_max), axis=0) - jacfwd = jax.jit(jax.jacfwd(array_output), static_argnums=(1, 2)) + jacfwd = jax.jit(jax.jacfwd(array_output), static_argnums=(1,)) time_deri = np.zeros(n_tries + warmup) for i in range(n_tries + warmup): elapsed = -time.time() - sh_jacfwd_jit = jacfwd(xyz, l_max, normalized) + sh_jacfwd_jit = jacfwd(xyz, l_max) elapsed += time.time() time_deri[i] = elapsed @@ -156,12 +155,12 @@ def array_output(xyz, l_max, normalized): if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) - jacrev = jax.jit(jax.jacrev(array_output), static_argnums=(1, 2)) + jacrev = jax.jit(jax.jacrev(array_output), static_argnums=(1,)) time_deri = np.zeros(n_tries + warmup) for i in range(n_tries + warmup): elapsed = -time.time() - sh_jacrev_jit = jacrev(xyz, l_max, normalized) + sh_jacrev_jit = jacrev(xyz, l_max) elapsed += time.time() time_deri[i] = elapsed From 586e8e33b73934c676cba0591c5deec782ef62a5 Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Fri, 6 Sep 2024 20:48:20 +0200 Subject: [PATCH 2/2] Update torch benchmarks --- benchmarks/pytorch/benchmark.py | 7 ++++++- benchmarks/pytorch/benchmark_second_derivatives.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/benchmarks/pytorch/benchmark.py b/benchmarks/pytorch/benchmark.py index 86c48918c..a2e023c2a 100644 --- a/benchmarks/pytorch/benchmark.py +++ b/benchmarks/pytorch/benchmark.py @@ -45,7 +45,12 @@ def sphericart_benchmark( warmup=16, ): xyz = torch.randn((n_samples, 3), dtype=dtype, device=device) - sh_calculator = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized) + + if normalized: + sh_calculator = sphericart.torch.SphericalHarmonics(l_max) + else: + sh_calculator = sphericart.torch.SolidHarmonics(l_max) + omp_threads = sh_calculator.omp_num_threads() print( f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, " diff --git a/benchmarks/pytorch/benchmark_second_derivatives.py b/benchmarks/pytorch/benchmark_second_derivatives.py index e79db581e..0c985dd91 100644 --- a/benchmarks/pytorch/benchmark_second_derivatives.py +++ b/benchmarks/pytorch/benchmark_second_derivatives.py @@ -46,9 +46,14 @@ def sphericart_benchmark( warmup=16, ): xyz = torch.randn((n_samples, 3), dtype=dtype, device=device, requires_grad=True) - sh_calculator = sphericart.torch.SphericalHarmonics( - l_max, normalized=normalized, backward_second_derivatives=True - ) + if normalized: + sh_calculator = sphericart.torch.SphericalHarmonics( + l_max, backward_second_derivatives=True + ) + else: + sh_calculator = sphericart.torch.SolidHarmonics( + l_max, backward_second_derivatives=True + ) omp_threads = sh_calculator.omp_num_threads() print( f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "