-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add additional dependencies for mypy in pre-commit #1292
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,15 @@ | ||
import jax.numpy as jnp | ||
from jax import Array | ||
from jax.typing import ArrayLike | ||
|
||
from timemachine.lib import custom_ops | ||
|
||
|
||
def fixed_to_float(v: int | jnp.uint64) -> jnp.float64: | ||
def fixed_to_float(v: ArrayLike) -> Array: | ||
"""Meant to imitate the logic of timemachine/cpp/src/fixed_point.hpp::FIXED_TO_FLOAT""" | ||
return jnp.float64(jnp.int64(jnp.uint64(v))) / custom_ops.FIXED_EXPONENT | ||
|
||
|
||
def float_to_fixed(v: jnp.float32 | float) -> jnp.uint64: | ||
def float_to_fixed(v: ArrayLike) -> Array: | ||
"""Meant to imitate the logic of timemachine/cpp/src/kernels/k_fixed_point.cuh::FLOAT_TO_FIXED""" | ||
return jnp.uint64(jnp.int64(v * custom_ops.FIXED_EXPONENT)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
from typing import Any | ||
from typing import Tuple, cast | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax import jit, vmap | ||
from jax import Array, jit, vmap | ||
from jax.scipy.special import erfc | ||
from jax.typing import ArrayLike | ||
from numpy.typing import NDArray | ||
from scipy.special import binom | ||
|
||
|
@@ -17,8 +18,6 @@ | |
process_traj_in_chunks, | ||
) | ||
|
||
Array = Any | ||
|
||
|
||
def switch_fn(dij, cutoff): | ||
return jnp.power(jnp.cos((jnp.pi * jnp.power(dij, 8)) / (2 * cutoff)), 2) | ||
|
@@ -68,7 +67,7 @@ def nonbonded_block_unsummed( | |
params_j: NDArray, | ||
beta: float, | ||
cutoff: float, | ||
): | ||
) -> Array: | ||
""" | ||
This is a modified version of `nonbonded` that computes a block of | ||
NxM interactions between two sets of particles x_i and x_j. It is assumed that | ||
|
@@ -129,7 +128,7 @@ def nonbonded_block_unsummed( | |
lj = lennard_jones(dij, sig_ij, eps_ij) | ||
|
||
nrgs = jnp.where(dij < cutoff, es + lj, 0) | ||
return nrgs | ||
return cast(Array, nrgs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curiosity: What type is this casting from? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The return type of |
||
|
||
|
||
def nonbonded_block(xi, xj, box, params_i, params_j, beta, cutoff): | ||
|
@@ -330,7 +329,7 @@ def nonbonded_on_specific_pairs( | |
beta: float, | ||
cutoff: Optional[float] = None, | ||
rescale_mask=None, | ||
): | ||
) -> Tuple[Array, Array]: | ||
"""See `nonbonded` docstring for more details | ||
|
||
Notes | ||
|
@@ -341,7 +340,7 @@ def nonbonded_on_specific_pairs( | |
""" | ||
|
||
if len(pairs) == 0: | ||
return np.zeros(1), np.zeros(1) | ||
return jnp.zeros(1), jnp.zeros(1) | ||
|
||
inds_l, inds_r = pairs.T | ||
|
||
|
@@ -377,7 +376,10 @@ def apply_cutoff(x): | |
rescale_electrostatics = rescale_mask[:, 0] | ||
electrostatics = jnp.where(rescale_electrostatics != 0, electrostatics * rescale_electrostatics, 0) | ||
|
||
return vdW, electrostatics | ||
vdW_arr = cast(Array, vdW) | ||
electrostatics_arr = cast(Array, electrostatics) | ||
|
||
return vdW_arr, electrostatics_arr | ||
|
||
|
||
def nonbonded_on_precomputed_pairs( | ||
|
@@ -474,7 +476,7 @@ def validate_coulomb_cutoff(cutoff=1.0, beta=2.0, threshold=1e-2): | |
# TODO: avoid repetition between this and lennard-jones | ||
|
||
|
||
def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> float: | ||
def coulomb_prefactor_on_atom(x_i, x_others, q_others, box=None, beta=2.0, cutoff=jnp.inf) -> Array: | ||
Comment on lines
-477
to
+479
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed with the PR comment -- may be desirable later to restore an annotation that functions like this return a scalar / are non-broadcasting (i.e. they need to be transformed by vmap or similar). But the docstring + context seems sufficient. |
||
"""Precompute part of (sum_i q_i * q_j / d_ij * rxn_field(d_ij)) that does not depend on q_i | ||
|
||
Parameters | ||
|
@@ -557,7 +559,7 @@ def f_snapshot(coords, box): | |
return process_traj_in_chunks(f_snapshot, traj, boxes, chunk_size) | ||
|
||
|
||
def coulomb_interaction_group_energy(q_ligand: Array, q_prefactors: Array) -> float: | ||
def coulomb_interaction_group_energy(q_ligand: ArrayLike, q_prefactors: ArrayLike) -> Array: | ||
"""Assuming q_prefactors = coulomb_prefactors_on_snapshot(x_ligand, ...), | ||
cheaply compute the energy of ligand-environment interaction group | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two seem more accurate but less precise