Skip to content

Commit

Permalink
adapt test
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Nov 28, 2024
1 parent 3684597 commit 593da72
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
4 changes: 2 additions & 2 deletions apax/nodes/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def run(self):
self.selected_ids = self.select_atoms(self.data)

@property
def atoms(self) -> list[ase.Atoms]:
def frames(self) -> list[ase.Atoms]:
"""Get a list of the selected atoms objects."""
return [atoms for i, atoms in enumerate(self.data) if i in self.selected_ids]

@property
def excluded_atoms(self) -> list[ase.Atoms]:
def excluded_frames(self) -> list[ase.Atoms]:
"""Get a list of the atoms objects that were not selected."""
return [atoms for i, atoms in enumerate(self.data) if i not in self.selected_ids]

Expand Down
10 changes: 3 additions & 7 deletions tests/nodes/test_n_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose):
processing_batch_size=4,
)

prediction = ips.analysis.Prediction(data=kernel_selection.atoms, model=ensemble)
prediction = ips.analysis.Prediction(data=kernel_selection.frames, model=ensemble)
analysis = ips.analysis.PredictionMetrics(
x=kernel_selection.atoms, y=prediction.atoms
x=kernel_selection.frames, y=prediction.atoms
)

proj.repro()
Expand All @@ -110,11 +110,7 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose):

assert atoms.get_potential_energy() < 0

uncertainty_selection.load()
kernel_selection.load()
md.load()

uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms]
assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms

assert len(kernel_selection.atoms) == selection_batch_size
assert len(kernel_selection.frames) == selection_batch_size

0 comments on commit 593da72

Please sign in to comment.