Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update benchmarks to new API #150

Merged
merged 2 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions benchmarks/jax/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ def sphericart_benchmark(
key = jax.random.PRNGKey(0)
xyz = jax.random.normal(key, (n_samples, 3), dtype=dtype)

sh_calculator = sphericart.jax.spherical_harmonics
if normalized:
sh_calculator = sphericart.jax.spherical_harmonics
else:
sh_calculator = sphericart.jax.solid_harmonics

sh_calculator_jit = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
sh_calculator_jit = jax.jit(sh_calculator, static_argnums=(1,))

print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
Expand All @@ -52,7 +55,7 @@ def sphericart_benchmark(

for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart = sh_calculator(xyz, l_max, normalized)
sh_sphericart = sh_calculator(xyz, l_max)
elapsed += time.time()
time_noderi[i] = elapsed

Expand All @@ -69,7 +72,7 @@ def sphericart_benchmark(

for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_jit = sh_calculator_jit(xyz, l_max, normalized)
sh_sphericart_jit = sh_calculator_jit(xyz, l_max)
elapsed += time.time()
time_noderi[i] = elapsed

Expand All @@ -82,15 +85,15 @@ def sphericart_benchmark(
if verbose:
print("Warm-up timings / sec.:\n", time_noderi[:warmup])

def scalar_output(xyz, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized))
def scalar_output(xyz, l_max):
return jax.numpy.sum(sh_calculator(xyz, l_max))

sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=(1, 2))
sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_grad_jit = sh_grad(xyz, l_max, normalized)
sh_sphericart_grad_jit = sh_grad(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand All @@ -103,21 +106,19 @@ def scalar_output(xyz, l_max, normalized):
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

def single_scalar_output(x, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized))
def single_scalar_output(x, l_max):
return jax.numpy.sum(sh_calculator(x, l_max))

# Compute the Hessian for a single (3,) input
single_hessian = jax.hessian(single_scalar_output)

# Use vmap to vectorize the Hessian computation over the first axis
sh_hess = jax.jit(
jax.vmap(single_hessian, in_axes=(0, None, None)), static_argnums=(1, 2)
)
sh_hess = jax.jit(jax.vmap(single_hessian, in_axes=(0, None)), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_hess_jit = sh_hess(xyz, l_max, normalized)
sh_sphericart_hess_jit = sh_hess(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand All @@ -133,17 +134,15 @@ def single_scalar_output(x, l_max, normalized):
# calculate a function of the spherical harmonics that returns an array
# and take its jacobian with respect to the input Cartesian coordinates,
# both in forward mode and in reverse mode
def array_output(xyz, l_max, normalized):
return jax.numpy.sum(
sphericart.jax.spherical_harmonics(xyz, l_max, normalized), axis=0
)
def array_output(xyz, l_max):
return jax.numpy.sum(sh_calculator(xyz, l_max), axis=0)

jacfwd = jax.jit(jax.jacfwd(array_output), static_argnums=(1, 2))
jacfwd = jax.jit(jax.jacfwd(array_output), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_jacfwd_jit = jacfwd(xyz, l_max, normalized)
sh_jacfwd_jit = jacfwd(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand All @@ -156,12 +155,12 @@ def array_output(xyz, l_max, normalized):
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

jacrev = jax.jit(jax.jacrev(array_output), static_argnums=(1, 2))
jacrev = jax.jit(jax.jacrev(array_output), static_argnums=(1,))

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_jacrev_jit = jacrev(xyz, l_max, normalized)
sh_jacrev_jit = jacrev(xyz, l_max)
elapsed += time.time()
time_deri[i] = elapsed

Expand Down
7 changes: 6 additions & 1 deletion benchmarks/pytorch/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def sphericart_benchmark(
warmup=16,
):
xyz = torch.randn((n_samples, 3), dtype=dtype, device=device)
sh_calculator = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized)

if normalized:
sh_calculator = sphericart.torch.SphericalHarmonics(l_max)
else:
sh_calculator = sphericart.torch.SolidHarmonics(l_max)

omp_threads = sh_calculator.omp_num_threads()
print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
Expand Down
11 changes: 8 additions & 3 deletions benchmarks/pytorch/benchmark_second_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ def sphericart_benchmark(
warmup=16,
):
xyz = torch.randn((n_samples, 3), dtype=dtype, device=device, requires_grad=True)
sh_calculator = sphericart.torch.SphericalHarmonics(
l_max, normalized=normalized, backward_second_derivatives=True
)
if normalized:
sh_calculator = sphericart.torch.SphericalHarmonics(
l_max, backward_second_derivatives=True
)
else:
sh_calculator = sphericart.torch.SolidHarmonics(
l_max, backward_second_derivatives=True
)
omp_threads = sh_calculator.omp_num_threads()
print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
Expand Down
Loading