Add additional dependencies for mypy in pre-commit #1292

Merged: 4 commits, Apr 25, 2024

Changes from all commits
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -27,8 +27,13 @@ repos:
     rev: v1.5.1
     hooks:
       - id: mypy
+        # include dependencies that export types (i.e. have a py.typed file in the root module) so that they can be used
+        # by mypy in pre-commit
+        additional_dependencies:
+          - "jax==0.4.14"
+          - "numpy==1.23.5"
+          - "scipy==1.10.1"
+          - "matplotlib==3.7.1"

       # Exclude custom_ops.py to work around clash with stub file when typechecking
       exclude: '^timemachine/lib/custom_ops.py$'
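Without these pins, the mypy hook runs in pre-commit's isolated environment, where jax, numpy, scipy, and matplotlib are not installed; imports from them then typically fail with errors like: Cannot find implementation or library stub for module named "jax". Listing them under additional_dependencies installs them, along with their py.typed type information, into the hook environment. To check the hook locally (standard pre-commit usage, not part of this diff):

    pre-commit run mypy --all-files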
4 changes: 2 additions & 2 deletions tests/nonbonded/test_nonbonded_jit.py
@@ -14,8 +14,8 @@ def test_nonbonded_reference_jittable(num_atom_idxs, rng: np.random.Generator):

     U_ref = Nonbonded(
         N,
-        exclusion_idxs=jnp.zeros((0,)),
-        scale_factors=jnp.zeros((0, 2)),
+        exclusion_idxs=np.zeros((0,), dtype=np.int32),
+        scale_factors=np.zeros((0, 2)),
         beta=1.0,
         cutoff=0.1,
         atom_idxs=np.arange(num_atom_idxs) if num_atom_idxs is not None else None,
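The swap from jnp.zeros to np.zeros matters for the dtype as well as for the type checker: jnp.zeros defaults to a floating dtype, while exclusion indices should be integers. A standalone sketch (not part of the diff):

    import jax.numpy as jnp
    import numpy as np

    jnp.zeros((0,)).dtype                 # float32 (float64 with x64 enabled): not an index dtype
    np.zeros((0,), dtype=np.int32).dtype  # int32, matching what Nonbonded expects for exclusion_idxs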
2 changes: 1 addition & 1 deletion tests/test_jax_nonbonded.py
@@ -153,7 +153,7 @@ def generate_random_inputs(n_atoms, dim, instance_flags=difficult_instance_flags

     min_dist = 0.1
     conf, box = resolve_clashes(conf, box, min_dist=min_dist)
-    box = np.array(box)
+    box = jnp.array(box)

     cutoff = 1.2
     if instance_flags["randomize_cutoff"]:
3 changes: 2 additions & 1 deletion timemachine/fe/energy_decomposition.py
@@ -7,6 +7,7 @@

 from timemachine.constants import BOLTZ, DEFAULT_TEMP
 from timemachine.lib.custom_ops import Potential
+from timemachine.potentials.types import Params

 Frames = TypeVar("Frames")
 Boxes = List[NDArray]
@@ -25,7 +26,7 @@ class EnergyDecomposedState(Generic[Frames]):

 def get_batch_u_fns(
     pots: Sequence[Potential],
-    params: Sequence[NDArray],
+    params: Sequence[Params],
     temperature: float = DEFAULT_TEMP,
 ) -> List[Batch_u_fn]:
     """Get a list of functions that take in (coords, boxes), return reduced_potentials
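Accepting Sequence[Params] instead of Sequence[NDArray] lets callers pass JAX arrays as well as NumPy arrays. A hypothetical sketch of what such an alias might look like (the actual definition lives in timemachine/potentials/types.py):

    from typing import Union

    from jax import Array
    from numpy.typing import NDArray

    Params = Union[NDArray, Array]  # hypothetical; see timemachine/potentials/types.py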
1 change: 1 addition & 0 deletions timemachine/fe/free_energy.py
@@ -440,6 +440,7 @@ def get_water_sampler_params(initial_state: InitialState) -> NDArray:
     assert isinstance(summed_pot.potentials[ixn_group_idx], NonbondedInteractionGroup)
     water_params = summed_pot.params_init[ixn_group_idx]
     assert water_params.shape[1] == 4
+    water_params = np.asarray(water_params)
     return water_params


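The added np.asarray makes the runtime value match the declared NDArray return type even when params_init holds JAX arrays. A standalone sketch (not part of the diff):

    import jax.numpy as jnp
    import numpy as np

    x = jnp.ones((3, 4))
    y = np.asarray(x)  # jax.Array -> np.ndarray, satisfying an NDArray return annotation
    type(y)            # <class 'numpy.ndarray'>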
2 changes: 1 addition & 1 deletion timemachine/fe/plots.py
@@ -491,7 +491,7 @@ def plot_hrex_replica_state_distribution_convergence(cumulative_replica_state_co
         ax.set_visible(False)

     fig.subplots_adjust(right=0.8, hspace=0.2, wspace=0.2)
-    cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
+    cbar_ax = fig.add_axes((0.85, 0.15, 0.02, 0.7))
     fig.colorbar(p, cax=cbar_ax, label=r"$\log_{10}$(number of iterations)")


35 changes: 21 additions & 14 deletions timemachine/fe/reweighting.py
@@ -5,36 +5,38 @@
"interpret_as_mixture_potential",
]

from typing import Any, Callable, Collection
from typing import Callable, Collection

import numpy as np
from jax import Array
from jax import numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

Samples = Collection
Params = Collection
Array = Any # see https://github.com/google/jax/issues/943
Energies = Array

 BatchedReducedPotentialFxn = Callable[[Samples, Params], Energies]


-def log_mean(log_values: Array) -> float:
+def log_mean(log_values: ArrayLike) -> Array:
     """stable log(mean(values))

     log(mean(values))
     = log(sum(values / len(values)))
     = logsumexp(log(values) - log(len(values)))
     """
+    log_values = jnp.asarray(log_values)
     return logsumexp(log_values - jnp.log(len(log_values)))


-def estimate_log_z_ratio(log_importance_weights: Array) -> float:
+def estimate_log_z_ratio(log_importance_weights: ArrayLike) -> Array:
     """stable log(mean(importance_weights))"""
     return log_mean(log_importance_weights)


-def one_sided_exp(delta_us: Array) -> float:
+def one_sided_exp(delta_us: ArrayLike) -> Array:
     """exponential averaging

     References
@@ -44,10 +46,11 @@ def one_sided_exp(delta_us: Array) -> float:
"""
# delta_us = -log_importance_weights
# delta_f = -log_z_ratio
delta_us = jnp.asarray(delta_us)
return -estimate_log_z_ratio(-delta_us)


-def interpret_as_mixture_potential(u_kn: Array, f_k: Array, N_k: Array) -> Array:
+def interpret_as_mixture_potential(u_kn: ArrayLike, f_k: ArrayLike, N_k: ArrayLike) -> Array:
     r"""Interpret samples from multiple states k as if they originate from a single state
     defined as a weighted mixture:

@@ -103,6 +106,9 @@ def interpret_as_mixture_potential(u_kn: Array, f_k: Array, N_k: Array) -> Array
     [3] [Elvira+, 2019] Generalized multiple importance sampling
         https://arxiv.org/abs/1511.03095
     """
+    u_kn = jnp.asarray(u_kn)
+    f_k = jnp.asarray(f_k)

     # one-liner: mixture_u_n = -logsumexp(f_k - u_kn.T, b=N_k, axis=1)

     # expanding steps:
@@ -141,7 +147,7 @@ def construct_endpoint_reweighting_estimator(
     batched_u_1_fxn: BatchedReducedPotentialFxn,
     ref_params: Params,
     ref_delta_f: float,
-) -> Callable[[Params], float]:
+) -> Callable[[Params], Array]:
     """assuming
     * endpoint samples (samples_0, samples_1)
     * precise estimate of free energy difference at initial params
@@ -181,17 +187,17 @@ def construct_endpoint_reweighting_estimator(
     ref_u_0 = batched_u_0_fxn(samples_0, ref_params)
     ref_u_1 = batched_u_1_fxn(samples_1, ref_params)

-    def endpoint_correction_0(params) -> float:
+    def endpoint_correction_0(params) -> Array:
         """estimate f(ref, 0) -> f(params, 0) by reweighting"""
         delta_us = batched_u_0_fxn(samples_0, params) - ref_u_0
         return one_sided_exp(delta_us)

-    def endpoint_correction_1(params) -> float:
+    def endpoint_correction_1(params) -> Array:
         """estimate f(ref, 1) -> f(params, 1) by reweighting"""
         delta_us = batched_u_1_fxn(samples_1, params) - ref_u_1
         return one_sided_exp(delta_us)

-    def estimate_delta_f(params: Params) -> float:
+    def estimate_delta_f(params: Params) -> Array:
         """estimate f(params, 1) - f(params, 0)

         using this thermodynamic cycle:
@@ -218,10 +224,10 @@

 def construct_mixture_reweighting_estimator(
     samples_n: Samples,
-    u_ref_n: Array,
+    u_ref_n: ArrayLike,
     batched_u_0_fxn: BatchedReducedPotentialFxn,
     batched_u_1_fxn: BatchedReducedPotentialFxn,
-) -> Callable[[Params], float]:
+) -> Callable[[Params], Array]:
     r"""assuming
     * samples x_n from a distribution p_ref(x) \propto exp(-u_ref(x))
       that has good overlap with BOTH p_0(params)(x) and p_1(params)(x),
@@ -275,19 +281,20 @@ def construct_mixture_reweighting_estimator(
     [3] Wieder et al. PyTorch implementation of differentiable reweighting in neutromeratio
         https://github.com/choderalab/neutromeratio/blob/2abf29f03e5175a988503b5d6ceeee8ce5bfd4ad/neutromeratio/parameter_gradients.py#L246-L267
     """
+    u_ref_n = jnp.asarray(u_ref_n)
     assert len(samples_n) == len(u_ref_n)

     def f_0(params):
         """estimate f(params, 0) - f(ref) by reweighting"""
         u_0_n = batched_u_0_fxn(samples_n, params)
         return one_sided_exp(u_0_n - u_ref_n)

-    def f_1(params) -> float:
+    def f_1(params) -> Array:
         """estimate f(params, 1) - f(ref) by reweighting"""
         u_1_n = batched_u_1_fxn(samples_n, params)
         return one_sided_exp(u_1_n - u_ref_n)

-    def estimate_delta_f(params) -> float:
+    def estimate_delta_f(params) -> Array:
         r"""estimate f(params, 1) - f(params, 0)

         using this thermodynamic cycle:
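The recurring pattern in this file (accept ArrayLike, coerce with jnp.asarray, return Array) follows JAX's own typing guidance: be permissive in inputs, precise in outputs. The numerical identity behind log_mean can be checked with a standalone sketch (not part of the diff):

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    log_values = jnp.array([1000.0, 1001.0, 1002.0])  # exp() of these overflows, even in float64
    jnp.log(jnp.mean(jnp.exp(log_values)))            # inf: the naive mean overflows
    logsumexp(log_values - jnp.log(len(log_values)))  # ~1001.31: stable log(mean(values))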
3 changes: 1 addition & 2 deletions timemachine/fe/utils.py
@@ -4,7 +4,6 @@

 import numpy as np
 from numpy.typing import NDArray
-from PIL import Image
 from rdkit import Chem
 from rdkit.Chem import AllChem, Draw
 from rdkit.Chem.Draw import rdMolDraw2D
@@ -214,7 +213,7 @@ def plot_atom_mapping_grid(
     core: NDArray,
     num_rotations: int = 5,
     seed: int = 1234,
-) -> Image:
+):
     mol_a_3d = recenter_mol(mol_a)
     mol_b_3d = recenter_mol(mol_b)

6 changes: 4 additions & 2 deletions timemachine/lib/fixed_point.py
@@ -1,13 +1,15 @@
 import jax.numpy as jnp
+from jax import Array
+from jax.typing import ArrayLike

 from timemachine.lib import custom_ops


-def fixed_to_float(v: int | jnp.uint64) -> jnp.float64:
+def fixed_to_float(v: ArrayLike) -> Array:
     """Meant to imitate the logic of timemachine/cpp/src/fixed_point.hpp::FIXED_TO_FLOAT"""
     return jnp.float64(jnp.int64(jnp.uint64(v))) / custom_ops.FIXED_EXPONENT


-def float_to_fixed(v: jnp.float32 | float) -> jnp.uint64:
+def float_to_fixed(v: ArrayLike) -> Array:
Comment on lines -6 to +13 (Collaborator): These two seem more accurate but less precise

"""Meant to imitate the logic of timemachine/cpp/src/kernels/k_fixed_point.cuh::FLOAT_TO_FIXED"""
return jnp.uint64(jnp.int64(v * custom_ops.FIXED_EXPONENT))
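A standalone round-trip sketch of the fixed-point encoding; FIXED_EXPONENT below is a hypothetical stand-in for custom_ops.FIXED_EXPONENT, and x64 must be enabled for the 64-bit dtypes:

    import jax

    jax.config.update("jax_enable_x64", True)  # required for uint64/int64/float64

    import jax.numpy as jnp

    FIXED_EXPONENT = 2**36  # hypothetical stand-in for custom_ops.FIXED_EXPONENT

    def float_to_fixed(v):
        return jnp.uint64(jnp.int64(v * FIXED_EXPONENT))

    def fixed_to_float(v):
        return jnp.float64(jnp.int64(jnp.uint64(v))) / FIXED_EXPONENT

    assert fixed_to_float(float_to_fixed(1.5)) == 1.5  # exact: 1.5 lies on the fixed-point grid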
2 changes: 1 addition & 1 deletion timemachine/maps/terminal_bonds.py
@@ -204,7 +204,7 @@ def contains_in_support(self, x) -> bool:
         interval = self.intervals[i]
         bond_valid.append((r <= interval.upper) * (r >= interval.lower))

-    return jnp.array(bond_valid).all()
+    return jnp.array(bond_valid).all().item()

     @classmethod
     def from_harmonic_bond_params(cls, bond_idxs, params, temperature=DEFAULT_TEMP, sigma_thresh=DEFAULT_SIGMA_THRESH):
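The added .item() converts the 0-d JAX array produced by .all() into the Python bool promised by the function's -> bool annotation. A standalone sketch (not part of the diff):

    import jax.numpy as jnp

    flags = jnp.array([True, True, False])
    type(flags.all())         # jax.Array (0-dimensional), which mypy won't accept as bool
    type(flags.all().item())  # bool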
14 changes: 9 additions & 5 deletions timemachine/optimize/protocol.py
@@ -40,17 +40,18 @@
 Gradient-based optimization of high-dimensional protocols, using a reweighting-based estimate of
 a T.I.-tailored objective, stddev(du/dlambda).
 """
-from typing import Any, Callable, cast
+from typing import Callable, cast

 import numpy as np
-from jax import jit
+from jax import Array, jit
 from jax import numpy as jnp
 from jax import vmap
 from jax.scipy.special import logsumexp
+from jax.typing import ArrayLike
 from scipy.optimize import bisect

 Float = float
-Array = Any  # see https://github.com/google/jax/issues/943
 DistanceFxn = Callable[[Float, Float], Float]
 WorkStddevEstimator = DistanceFxn

@@ -99,7 +100,7 @@ def rebalance_initial_protocol(
     return optimized_protocol


-def log_weights_from_mixture(u_kn: Array, f_k: Array, N_k: Array) -> Array:
+def log_weights_from_mixture(u_kn: ArrayLike, f_k: ArrayLike, N_k: ArrayLike) -> Array:
     r"""Assuming
     * K reduced potential energy functions u_k
     * N_k samples from each state e^{-u_k} / Z_k
@@ -111,6 +112,9 @@ def log_weights_from_mixture(u_kn: Array, f_k: Array, N_k: Array) -> Array:
     interpret the collection of N = \sum_k N_k samples as coming from a
     mixture of states p(x) = (1 / K) \sum_k e^-u_k / Z_k
     """
+    f_k = jnp.asarray(f_k)
+    u_kn = jnp.asarray(u_kn)

     log_q_k = f_k - u_kn.T
     N_k = np.array(N_k, dtype=np.float64)  # may be ints, or in a list...
     log_weights = logsumexp(log_q_k, b=N_k, axis=1)
@@ -121,7 +125,7 @@ def linear_u_kn_interpolant(lambdas: Array, u_kn: Array) -> Callable:
"""Given a matrix u_kn[k, n] = u(xs[n], lambdas[k]) produce linear interpolated estimates of u(xs[n], lam)
at arbitrary new values lam"""

def u_interp(u_n: Array, lam: Float) -> Float:
def u_interp(u_n: ArrayLike, lam: ArrayLike) -> Array:
return jnp.nan_to_num(jnp.interp(lam, lambdas, u_n), nan=+jnp.inf, posinf=+jnp.inf)

@jit
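The jnp.asarray coercions at the top of these functions are what make the ArrayLike annotations honest: callers may pass plain lists, which support neither .T nor broadcast arithmetic. A standalone sketch (not part of the diff):

    import jax.numpy as jnp

    u_kn = [[0.0, 1.0], [2.0, 3.0]]  # ArrayLike: a plain nested list is allowed...
    # u_kn.T                         # ...but raises AttributeError: 'list' object has no attribute 'T'
    jnp.asarray(u_kn).T              # coercing up front makes the function robust to such inputs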
24 changes: 13 additions & 11 deletions timemachine/potentials/nonbonded.py
@@ -1,9 +1,10 @@
-from typing import Any
+from typing import Tuple, cast

 import jax.numpy as jnp
 import numpy as np
-from jax import jit, vmap
+from jax import Array, jit, vmap
 from jax.scipy.special import erfc
+from jax.typing import ArrayLike
 from numpy.typing import NDArray
 from scipy.special import binom
@@ -17,8 +18,6 @@
     process_traj_in_chunks,
 )

-Array = Any


 def switch_fn(dij, cutoff):
     return jnp.power(jnp.cos((jnp.pi * jnp.power(dij, 8)) / (2 * cutoff)), 2)
@@ -68,7 +67,7 @@ def nonbonded_block_unsummed(
     params_j: NDArray,
     beta: float,
     cutoff: float,
-):
+) -> Array:
     """
     This is a modified version of `nonbonded` that computes a block of
     NxM interactions between two sets of particles x_i and x_j. It is assumed that
@@ -129,7 +128,7 @@ def nonbonded_block_unsummed(
     lj = lennard_jones(dij, sig_ij, eps_ij)

     nrgs = jnp.where(dij < cutoff, es + lj, 0)
-    return nrgs
+    return cast(Array, nrgs)
Comment (Collaborator): curiosity: What type is this casting from?

Reply (mcwitt, Collaborator Author, Apr 25, 2024): The return type of jnp.where is Array | tuple[Array, ...], I think because when where is called with only one argument the behavior is different. In this case, we know it will return an Array. (Note that cast is a no-op at runtime; it serves as an assertion to the type checker that a value has a certain type.)
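A standalone sketch of the two jnp.where call forms and of cast being a runtime no-op (not part of the diff):

    from typing import cast

    import jax.numpy as jnp
    from jax import Array

    x = jnp.array([1.0, -2.0, 3.0])
    jnp.where(x > 0)              # one-argument form: a tuple of index arrays, like np.nonzero
    y = jnp.where(x > 0, x, 0.0)  # three-argument form: elementwise select, a single Array
    y = cast(Array, y)            # does nothing at runtime; narrows the static type for mypy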



def nonbonded_block(xi, xj, box, params_i, params_j, beta, cutoff):
@@ -330,7 +329,7 @@ def nonbonded_on_specific_pairs(
     beta: float,
     cutoff: Optional[float] = None,
     rescale_mask=None,
-):
+) -> Tuple[Array, Array]:
     """See `nonbonded` docstring for more details

     Notes
@@ -341,7 +340,7 @@
"""

if len(pairs) == 0:
return np.zeros(1), np.zeros(1)
return jnp.zeros(1), jnp.zeros(1)

     inds_l, inds_r = pairs.T

@@ -377,7 +376,10 @@ def apply_cutoff(x):
         rescale_electrostatics = rescale_mask[:, 0]
         electrostatics = jnp.where(rescale_electrostatics != 0, electrostatics * rescale_electrostatics, 0)

-    return vdW, electrostatics
+    vdW_arr = cast(Array, vdW)
+    electrostatics_arr = cast(Array, electrostatics)
+
+    return vdW_arr, electrostatics_arr


def nonbonded_on_precomputed_pairs(
@@ -474,7 +476,7 @@ def validate_coulomb_cutoff(cutoff=1.0, beta=2.0, threshold=1e-2):
 # TODO: avoid repetition between this and lennard-jones


-def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> float:
+def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> Array:
Comment on lines -477 to +479 (Collaborator): Agreed with the PR comment -- may be desirable later to restore an annotation that functions like this return a scalar / are non-broadcasting (i.e. they need to be transformed by vmap or similar). But the docstring + context seems sufficient.
"""Precompute part of (sum_i q_i * q_j / d_ij * rxn_field(d_ij)) that does not depend on q_i

Parameters
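A standalone sketch of the review comment's point about non-broadcasting functions: a scalar-output, per-atom function is lifted over many atoms with vmap (the function below is illustrative, not the real coulomb_prefactor_on_atom):

    import jax.numpy as jnp
    from jax import vmap

    def prefactor_on_atom(x_i, x_others):  # hypothetical per-atom, scalar-output function
        return jnp.sum(1.0 / jnp.linalg.norm(x_others - x_i, axis=1))

    xs = jnp.arange(1.0, 6.0)[:, None] * jnp.ones(3)  # 5 hypothetical "ligand" positions
    env = jnp.zeros((8, 3))                           # 8 hypothetical "environment" positions
    prefactors = vmap(prefactor_on_atom, in_axes=(0, None))(xs, env)  # shape (5,)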
@@ -557,7 +559,7 @@ def f_snapshot(coords, box):
     return process_traj_in_chunks(f_snapshot, traj, boxes, chunk_size)


-def coulomb_interaction_group_energy(q_ligand: Array, q_prefactors: Array) -> float:
+def coulomb_interaction_group_energy(q_ligand: ArrayLike, q_prefactors: ArrayLike) -> Array:
     """Assuming q_prefactors = coulomb_prefactors_on_snapshot(x_ligand, ...),
     cheaply compute the energy of ligand-environment interaction group

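Given prefactors that already fold in the distances and environment charges, the interaction-group energy presumably reduces to an inner product over ligand charges; a hypothetical sketch (not necessarily the repo's implementation):

    import jax.numpy as jnp

    q_ligand = jnp.array([0.3, -0.1, -0.2])     # hypothetical ligand partial charges
    q_prefactors = jnp.array([-1.2, 0.8, 0.4])  # precomputed, independent of ligand charges
    energy = jnp.dot(q_prefactors, q_ligand)    # cheap ligand-environment Coulomb energy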