Skip to content

Commit

Permalink
Update benchmarks to new API (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm authored Sep 7, 2024
1 parent 0b19d52 commit 74ec799
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
43 changes: 21 additions & 22 deletions benchmarks/jax/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ def sphericart_benchmark(
key = jax.random.PRNGKey(0)
xyz = jax.random.normal(key, (n_samples, 3), dtype=dtype)

sh_calculator = sphericart.jax.spherical_harmonics
if normalized:
sh_calculator = sphericart.jax.spherical_harmonics
else:
sh_calculator = sphericart.jax.solid_harmonics

sh_calculator_jit = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
sh_calculator_jit = jax.jit(sh_calculator, static_argnums=(1,))

print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
Expand All @@ -52,7 +55,7 @@ def sphericart_benchmark(

for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart = sh_calculator(xyz, l_max, normalized)
sh_sphericart = sh_calculator(xyz, l_max)
elapsed += time.time()
time_noderi[i] = elapsed

Expand All @@ -69,7 +72,7 @@ def sphericart_benchmark(

for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_jit = sh_calculator_jit(xyz, l_max, normalized)
sh_sphericart_jit = sh_calculator_jit(xyz, l_max)
elapsed += time.time()
time_noderi[i] = elapsed

Expand All @@ -82,15 +85,15 @@ def sphericart_benchmark(
if verbose:
print("Warm-up timings / sec.:\n", time_noderi[:warmup])

def scalar_output(xyz, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized))
def scalar_output(xyz, l_max):
return jax.numpy.sum(sh_calculator(xyz, l_max))

sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=(1, 2))
sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_grad_jit = sh_grad(xyz, l_max, normalized)
sh_sphericart_grad_jit = sh_grad(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand All @@ -103,21 +106,19 @@ def scalar_output(xyz, l_max, normalized):
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

def single_scalar_output(x, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized))
def single_scalar_output(x, l_max):
return jax.numpy.sum(sh_calculator(x, l_max))

# Compute the Hessian for a single (3,) input
single_hessian = jax.hessian(single_scalar_output)

# Use vmap to vectorize the Hessian computation over the first axis
sh_hess = jax.jit(
jax.vmap(single_hessian, in_axes=(0, None, None)), static_argnums=(1, 2)
)
sh_hess = jax.jit(jax.vmap(single_hessian, in_axes=(0, None)), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_hess_jit = sh_hess(xyz, l_max, normalized)
sh_sphericart_hess_jit = sh_hess(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand All @@ -133,17 +134,15 @@ def single_scalar_output(x, l_max, normalized):
# calculate a function of the spherical harmonics that returns an array
# and take its jacobian with respect to the input Cartesian coordinates,
# both in forward mode and in reverse mode
def array_output(xyz, l_max, normalized):
return jax.numpy.sum(
sphericart.jax.spherical_harmonics(xyz, l_max, normalized), axis=0
)
def array_output(xyz, l_max):
return jax.numpy.sum(sh_calculator(xyz, l_max), axis=0)

jacfwd = jax.jit(jax.jacfwd(array_output), static_argnums=(1, 2))
jacfwd = jax.jit(jax.jacfwd(array_output), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_jacfwd_jit = jacfwd(xyz, l_max, normalized)
sh_jacfwd_jit = jacfwd(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand All @@ -156,12 +155,12 @@ def array_output(xyz, l_max, normalized):
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

jacrev = jax.jit(jax.jacrev(array_output), static_argnums=(1, 2))
jacrev = jax.jit(jax.jacrev(array_output), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_jacrev_jit = jacrev(xyz, l_max, normalized)
sh_jacrev_jit = jacrev(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand Down
7 changes: 6 additions & 1 deletion benchmarks/pytorch/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def sphericart_benchmark(
warmup=16,
):
xyz = torch.randn((n_samples, 3), dtype=dtype, device=device)
sh_calculator = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized)

if normalized:
sh_calculator = sphericart.torch.SphericalHarmonics(l_max)
else:
sh_calculator = sphericart.torch.SolidHarmonics(l_max)

omp_threads = sh_calculator.omp_num_threads()
print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
Expand Down
11 changes: 8 additions & 3 deletions benchmarks/pytorch/benchmark_second_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ def sphericart_benchmark(
warmup=16,
):
xyz = torch.randn((n_samples, 3), dtype=dtype, device=device, requires_grad=True)
sh_calculator = sphericart.torch.SphericalHarmonics(
l_max, normalized=normalized, backward_second_derivatives=True
)
if normalized:
sh_calculator = sphericart.torch.SphericalHarmonics(
l_max, backward_second_derivatives=True
)
else:
sh_calculator = sphericart.torch.SolidHarmonics(
l_max, backward_second_derivatives=True
)
omp_threads = sh_calculator.omp_num_threads()
print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
Expand Down

0 comments on commit 74ec799

Please sign in to comment.