Skip to content

Commit

Permalink
Make TestWorkerState aware of prediction tags
Browse files Browse the repository at this point in the history
This adds support to TestWorkerState for prediction tags as added in #2020.  We
want to test that if we tag predictions and subscribe to those tags, the worker
still behaves as expected and we receive the events we expect to receive.

This also updates cancel to assert that send_cancel() was called on the child
worker.
  • Loading branch information
philandstuff committed Nov 25, 2024
1 parent 3e56e59 commit 2bc4710
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class SetupState:

@frozen
class PredictState:
tag: Optional[str]
payload: Dict[str, Any]
fut: "Future[Done]"
result: Result
Expand Down Expand Up @@ -516,16 +517,16 @@ def __init__(self):

self.worker = Worker(child=self.child, events=parent_conn)

def simulate_events(self, events, *, target=None):
def simulate_events(self, events, *, tag=None, target=None):
def _handle_event(ev):
if target:
target.handle_event(ev)
self.pending.release()

subid = self.worker.subscribe(_handle_event)
subid = self.worker.subscribe(_handle_event, tag=tag)
try:
for event in events:
self.child_events.send(event)
self.child_events.send(Envelope(event, tag=tag))
self.pending.acquire()
finally:
self.worker.unsubscribe(subid)
Expand All @@ -545,18 +546,18 @@ def setup(self):
source=st.sampled_from(["stdout", "stderr"]),
)
def simulate_setup_logs(self, state: SetupState, text: str, source: str):
events = [Envelope(Log(source=source, message=text))]
events = [Log(source=source, message=text)]
self.simulate_events(events, target=state.result)

@rule(state=consumes(setup_pending), target=setup_complete)
def simulate_setup_success(self, state: SetupState):
self.simulate_events(events=[Envelope(Done())], target=state.result)
self.simulate_events(events=[Done()], target=state.result)
return state

@rule(state=consumes(setup_pending), target=setup_complete)
def simulate_setup_failure(self, state: SetupState):
self.simulate_events(
events=[Envelope(Done(error=True, error_detail="Setup failed!"))],
events=[Done(error=True, error_detail="Setup failed!")],
target=state.result,
)
return evolve(state, error=True)
Expand All @@ -576,25 +577,26 @@ def await_setup(self, state: SetupState):
@rule(
target=predict_pending,
name=ST_NAMES,
tag=st.uuids(),
steps=st.integers(min_value=0, max_value=5),
)
def predict(self, name: str, steps: int) -> PredictState:
def predict(self, name: str, steps: int, tag: uuid.UUID) -> PredictState:
payload = {"name": name, "steps": steps}
try:
fut = self.worker.predict(payload)
fut = self.worker.predict(payload, tag=tag.hex)
except InvalidStateException:
return multiple()
else:
return PredictState(payload=payload, fut=fut, result=Result())
return PredictState(tag=tag.hex, payload=payload, fut=fut, result=Result())

@rule(
state=predict_pending,
text=st.text(),
source=st.sampled_from(["stdout", "stderr"]),
)
def simulate_predict_logs(self, state: PredictState, text: str, source: str):
events = [Envelope(Log(source=source, message=text))]
self.simulate_events(events, target=state.result)
events = [Log(source=source, message=text)]
self.simulate_events(events, tag=state.tag, target=state.result)

@rule(state=consumes(predict_pending), target=predict_complete)
def simulate_predict_success(self, state: PredictState):
Expand All @@ -604,34 +606,32 @@ def simulate_predict_success(self, state: PredictState):
name = state.payload["name"]

if steps == 1:
events.append(Envelope(PredictionOutputType(multi=False)))
events.append(Envelope(PredictionOutput(payload=f"NAME={name}")))
events.append(PredictionOutputType(multi=False))
events.append(PredictionOutput(payload=f"NAME={name}"))

elif steps > 1:
events.append(Envelope(PredictionOutputType(multi=True)))
events.append(PredictionOutputType(multi=True))
for i in range(steps):
events.append(
Envelope(PredictionOutput(payload=f"NAME={name},STEP={i+1}"))
PredictionOutput(payload=f"NAME={name},STEP={i+1}"),
)

events.append(Envelope(Done(canceled=state.canceled)))
events.append(Done(canceled=state.canceled))

self.simulate_events(events, target=state.result)
self.simulate_events(events, tag=state.tag, target=state.result)
return state

@rule(state=consumes(predict_pending), target=predict_complete)
def simulate_predict_failure(self, state: PredictState):
events = [
Envelope(
Done(
error=True,
error_detail="Kaboom!",
canceled=state.canceled,
)
)
Done(
error=True,
error_detail="Kaboom!",
canceled=state.canceled,
),
]

self.simulate_events(events, target=state.result)
self.simulate_events(events, tag=state.tag, target=state.result)
return evolve(state, error=True)

@rule(state=consumes(predict_complete))
Expand Down Expand Up @@ -669,7 +669,8 @@ def await_predict(self, state: PredictState):
state=consumes(predict_pending),
)
def cancel(self, state: PredictState):
self.worker.cancel()
self.worker.cancel(tag=state.tag)
assert self.child.cancel_sent
return evolve(state, canceled=True)

def teardown(self):
Expand Down

0 comments on commit 2bc4710

Please sign in to comment.