diff --git a/burr/core/application.py b/burr/core/application.py index 66b5611a..b4387b45 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -1404,6 +1404,14 @@ def _load_from_persister(self): self.state = self.state.update(**self.default_state) self.sequence_id = None # has to start at None else: + if load_result["state"] is None: + raise ValueError( + f"Error: {self.initializer.__class__.__name__} returned {load_result} for " + f"partition_key:{self.partition_key}, app_id:{self.app_id}, " + f"sequence_id:{self.sequence_id}, " + f"but state was None! This is not allowed. Please return None in this case, or double " + f"check that persisted state can never be a None value." + ) # there was something last_position = load_result["position"] self.state = load_result["state"] diff --git a/tests/core/test_application.py b/tests/core/test_application.py index af0465c9..40f34ae6 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -1,7 +1,7 @@ import asyncio import collections import logging -from typing import Any, Awaitable, Callable, Dict, Generator, Tuple +from typing import Any, Awaitable, Callable, Dict, Generator, Literal, Optional, Tuple import pytest @@ -35,7 +35,7 @@ _validate_start, _validate_transitions, ) -from burr.core.persistence import DevNullPersister +from burr.core.persistence import BaseStatePersister, DevNullPersister, PersistedStateData from burr.lifecycle import ( PostRunStepHook, PostRunStepHookAsync, @@ -1670,6 +1670,57 @@ def test_application_builder_initialize_does_not_allow_state_setting(): ) +class BrokenPersister(BaseStatePersister): + """Broken persistor.""" + + def load( + self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + ) -> Optional[PersistedStateData]: + return dict( + partition_key="key", + app_id="id", + sequence_id=0, + position="foo", + state=None, + created_at="", + status="completed", + ) + + def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: + return [] + + def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: State, + status: Literal["completed", "failed"], + **kwargs, + ): + return + + +def test_application_builder_initialize_raises_on_broken_persistor(): + """Persisters should return None when there is no state to be loaded and the default used.""" + with pytest.raises(ValueError, match="but state was None"): + counter_action = base_counter_action.with_name("counter") + result_action = Result("count").with_name("result") + ( + ApplicationBuilder() + .with_actions(counter_action, result_action) + .with_transitions(("counter", "result", default)) + .initialize_from( + BrokenPersister(), + resume_at_next_action=True, + default_state={}, + default_entrypoint="foo", + ) + .build() + ) + + def test_application_builder_assigns_correct_actions_with_dual_api(): counter_action = base_counter_action.with_name("counter") result_action = Result("count")