Skip to content

Commit

Permalink
Fix test_capp.
Browse files Browse the repository at this point in the history
  • Loading branch information
qchempku2017 committed Jan 21, 2025
1 parent 3c43ad4 commit 3ead63b
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 15 deletions.
3 changes: 2 additions & 1 deletion smol/capp/generate/groundstate/upper_bound/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def __init__(

self._other_constraints = other_constraints
if initial_occupancy is not None:
self.initial_occupancy = np.array(initial_occupancy, dtype=int)
# Enforce int32.
self.initial_occupancy = np.array(initial_occupancy, dtype=np.int32)
else:
self.initial_occupancy = None

Expand Down
4 changes: 2 additions & 2 deletions smol/capp/generate/groundstate/upper_bound/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def get_auxiliary_variable_values(
Values of auxiliary variables subjecting to auxiliary constraints,
as 0 and 1.
"""
variable_values = np.array(variable_values).astype(int)
variable_values = np.array(variable_values, dtype=int)
aux_values = np.ones(len(indices_in_auxiliary_products), dtype=int)
for i, inds in enumerate(indices_in_auxiliary_products):
aux_values[i] = np.product(variable_values[inds])
aux_values[i] = np.prod(variable_values[inds])

return aux_values.astype(int)

Expand Down
8 changes: 5 additions & 3 deletions smol/capp/generate/groundstate/upper_bound/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ def get_occupancy_from_variables(
List of variable indices corresponding to each site index and
the species in its site space.
Returns:
np.ndarray: Encoded occupancy string.
np.ndarray[np.int32]: Encoded occupancy string.
"""
values = np.round(variable_values).astype(int)

num_sites = len(variable_indices)
occu = np.zeros(num_sites, dtype=int) - 1
# Enforce int32.
occu = np.zeros(num_sites) - 1
site_sublattice_ids = get_sublattice_indices_by_site(sublattices)

# Not considering species encoding order yet.
Expand All @@ -162,7 +163,8 @@ def get_occupancy_from_variables(
if np.any(occu < 0):
raise ValueError(f"Variables does not match given indices: {variable_indices}!")

return occu
# Enforce int32.
return occu.astype(np.int32)


def get_variable_values_from_occupancy(
Expand Down
3 changes: 2 additions & 1 deletion smol/capp/generate/special/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,9 @@ def __init__(
self._kernel.kB = 1.0 # set kB to 1.0 units

# get a trial trace to initialize sample container trace
# Enforce int32.
_trace = self._kernel.compute_initial_trace(
np.zeros(kernels[0].ensemble.num_sites, dtype=int)
np.zeros(kernels[0].ensemble.num_sites, dtype=np.int32)
)
sample_trace = Trace(
**{
Expand Down
3 changes: 2 additions & 1 deletion tests/test_capp/test_solver/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def test_solve(simple_solver):
else:
other_states = set(permutations([0] * 4 + [1] * 4))
for other_state in other_states:
other_state = np.array(list(other_state), dtype=int)
# Enforce int32.
other_state = np.array(list(other_state), dtype=np.int32)
other_feats = simple_solver.ensemble.compute_feature_vector(other_state)
other_energy = np.dot(other_feats, simple_solver.ensemble.natural_parameters)
# allow just a tiny slack.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_capp/test_solver/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,4 +246,4 @@ def test_expression_from_terms():
rand_aux_vals = get_auxiliary_variable_values(rand_vals, indices)
for ii, inds in enumerate(indices):
assert len(inds) > 1
assert np.product(rand_vals[inds]) == rand_aux_vals[ii]
assert np.prod(rand_vals[inds]) == rand_aux_vals[ii]
14 changes: 8 additions & 6 deletions tests/test_capp/test_solver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_random_solver_test_occu(sublattices):
]
f_code = anion_sublattice.encoding[anion_sublattice.species.index(Species("F", -1))]

occu = np.zeros(24, dtype=int) - 1
occu = np.zeros(24, dtype=np.int32) - 1
occu[li_sites] = li_code
occu[mn2_sites] = mn2_code
occu[mn4_sites] = mn4_code
Expand All @@ -67,7 +67,7 @@ def get_random_solver_test_occu(sublattices):

assert np.all(occu >= 0)

return occu
return occu.astype(np.int32)


def get_random_variable_values(sublattices):
Expand Down Expand Up @@ -124,7 +124,7 @@ def get_random_neutral_occupancy(
rand_occu = initial_occupancy.copy()
for site_id, code in flip:
rand_occu[site_id] = code
return rand_occu
return rand_occu.astype(np.int32)


def get_random_neutral_variable_values(
Expand All @@ -139,7 +139,8 @@ def get_random_neutral_variable_values(

def validate_correlations_from_occupancy(expansion_processor, occupancy):
# Check whether our interpretation of corr function is correct.
occupancy = np.array(occupancy, dtype=int)
# Enforce int32.
occupancy = np.array(occupancy, dtype=np.int32)
space = expansion_processor.cluster_subspace
sc_matrix = expansion_processor.supercell_matrix
mappings = space.supercell_orbit_mappings(sc_matrix)
Expand Down Expand Up @@ -168,7 +169,8 @@ def validate_correlations_from_occupancy(expansion_processor, occupancy):
def validate_interactions_from_occupancy(decomposition_processor, occupancy):
# Check whether our interpretation of corr function is correct.
orbit_tensors = decomposition_processor._interaction_tensors
occupancy = np.array(occupancy, dtype=int)
# Enforce int32.
occupancy = np.array(occupancy, dtype=np.int32)
space = decomposition_processor.cluster_subspace
sc_matrix = decomposition_processor.supercell_matrix
mappings = space.supercell_orbit_mappings(sc_matrix)
Expand Down Expand Up @@ -201,6 +203,6 @@ def evaluate_correlations_from_variable_values(grouped_terms, variable_values):
if len(var_inds) == 0:
f += corr_factor
else:
f += corr_factor * np.product(variable_values[var_inds])
f += corr_factor * np.prod(variable_values[var_inds])
corr.append(f)
return np.array(corr)

0 comments on commit 3ead63b

Please sign in to comment.