Skip to content

Commit

Permalink
enormously ugly iterating over the buffering, tell_numpy process. got…
Browse files Browse the repository at this point in the history
…ta deal with getting a variable number of responses
  • Loading branch information
jlnav committed Sep 27, 2024
1 parent 1ef5898 commit 8371d97
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions libensemble/gen_classes/aposmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,23 @@ def __init__(
self.ask_idx = 0
self.last_ask = None
self.tell_buf = None
self.num_evals = 0
self.n_buffd_results = 0
self._told_initial_sample = False

def _slot_in_data(self, results):
"""Slot in libE_calc_in and trial data into corresponding array fields."""
for field in ["f", "x", "x_on_cube", "sim_id", "local_pt"]:
self.tell_buf[field] = results[field]
indexes = results["sim_id"]
fields = results.dtype.names

Check warning on line 48 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L47-L48

Added lines #L47 - L48 were not covered by tests
for j, ind in enumerate(indexes):
for field in fields:
if np.isscalar(results[field][j]) or results.dtype[field].hasobject:
self.tell_buf[field][ind] = results[field][j]

Check warning on line 52 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L52

Added line #L52 was not covered by tests
else:
field_size = len(results[field][j])

Check warning on line 54 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L54

Added line #L54 was not covered by tests
if field_size == len(self.tell_buf[field][ind]):
self.tell_buf[field][ind] = results[field][j]

Check warning on line 56 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L56

Added line #L56 was not covered by tests
else:
self.tell_buf[field][ind][:field_size] = results[field][j]

Check warning on line 58 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L58

Added line #L58 was not covered by tests

@property
def _array_size(self):
Expand All @@ -56,12 +66,12 @@ def _array_size(self):
@property
def _enough_initial_sample(self):
"""We're typically happy with at least 90% of the initial sample."""
return self.num_evals > int(0.9 * self.gen_specs["user"]["initial_sample_size"])
return self.n_buffd_results > int(0.9 * self.gen_specs["user"]["initial_sample_size"])

Check warning on line 69 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L69

Added line #L69 was not covered by tests

@property
def _enough_subsequent_points(self):
"""But we need to evaluate at least N points, for the N local-optimization processes."""
return self.num_evals >= self.gen_specs["user"]["max_active_runs"]
return self.n_buffd_results >= self.gen_specs["user"]["max_active_runs"]

Check warning on line 74 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L74

Added line #L74 was not covered by tests

def ask_numpy(self, num_points: int = 0) -> npt.NDArray:
"""Request the next set of points to evaluate, as a NumPy array."""
Expand All @@ -88,20 +98,24 @@ def ask_numpy(self, num_points: int = 0) -> npt.NDArray:
return results

def tell_numpy(self, results: npt.NDArray, tag: int = EVAL_GEN_TAG) -> None:
if tag == PERSIS_STOP:
if results is None and tag == PERSIS_STOP:
super().tell_numpy(results, tag)
return
if self.num_evals == 0:
if len(results) == self._array_size: # DONT NEED TO COPY OVER IF THE INPUT ARRAY IS THE CORRECT SIZE
self._told_initial_sample = True # we definitely got an initial sample already if one matches
super().tell_numpy(results, tag)
return
if self.n_buffd_results == 0:
self.tell_buf = np.zeros(self._array_size, dtype=self.gen_specs["out"] + [("f", float)])
self._slot_in_data(results)
self.num_evals += len(results)
self.n_buffd_results += len(results)

Check warning on line 111 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L109-L111

Added lines #L109 - L111 were not covered by tests
if not self._told_initial_sample and self._enough_initial_sample:
super().tell_numpy(self.tell_buf, tag)
self._told_initial_sample = True
self.num_evals = 0
self.n_buffd_results = 0

Check warning on line 115 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L113-L115

Added lines #L113 - L115 were not covered by tests
elif self._told_initial_sample and self._enough_subsequent_points:
super().tell_numpy(self.tell_buf, tag)
self.num_evals = 0
self.n_buffd_results = 0

Check warning on line 118 in libensemble/gen_classes/aposmm.py

View check run for this annotation

Codecov / codecov/patch

libensemble/gen_classes/aposmm.py#L117-L118

Added lines #L117 - L118 were not covered by tests

def ask_updates(self) -> List[npt.NDArray]:
"""Request a list of NumPy arrays containing entries that have been identified as minima."""
Expand Down

0 comments on commit 8371d97

Please sign in to comment.