Add warning for unexpected model output in batched prediction #300

Merged (24 commits, Sep 26, 2024)

Commits
- 51a9dc7 fix: add warning unexpected output from HF model (closes #294) (grumpyp, Sep 25, 2024)
- 835776a add: warning if batched loop return string (grumpyp, Sep 25, 2024)
- 0eb5e5d Merge branch 'Lightning-AI:main' into main (grumpyp, Sep 25, 2024)
- 353b0fc [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2024)
- 5d577a7 move output check to unbatch_no_stream (grumpyp, Sep 25, 2024)
- 31185fb Merge branch 'main' of https://github.com/grumpyp/LitServe (grumpyp, Sep 25, 2024)
- 874e7d6 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2024)
- 6bbc319 restore format loops.py (grumpyp, Sep 25, 2024)
- 38faf79 Merge branch 'main' of https://github.com/grumpyp/LitServe (grumpyp, Sep 25, 2024)
- 9897fa9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2024)
- 9c86c1b fix: lint E501 Line too long (grumpyp, Sep 26, 2024)
- 771b185 Merge branch 'main' of https://github.com/grumpyp/LitServe (grumpyp, Sep 26, 2024)
- 0164e88 Update tests/test_batch.py (grumpyp, Sep 26, 2024)
- 9c66bff Update src/litserve/api.py (grumpyp, Sep 26, 2024)
- 9221ffb Update test to match new warning string (grumpyp, Sep 26, 2024)
- d27b9ab [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2024)
- 1e62ca7 Delete whitespace in warning string test_batch (grumpyp, Sep 26, 2024)
- 84db50b Update test_batch.py (grumpyp, Sep 26, 2024)
- 77fd337 Merge branch 'main' into main (grumpyp, Sep 26, 2024)
- bcf03dc Update test_batch.py (aniketmaurya, Sep 26, 2024)
- 0e8fd71 Update warning copy (aniketmaurya, Sep 26, 2024)
- 2936005 update test (aniketmaurya, Sep 26, 2024)
- 25e6484 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2024)
- 4bbc92f fix test (aniketmaurya, Sep 26, 2024)
src/litserve/api.py (7 additions, 0 deletions)

@@ -63,6 +63,13 @@ def predict(self, x, **kwargs):
pass

def _unbatch_no_stream(self, output):
if isinstance(output, str):
warnings.warn(
"The 'predict' method returned a string instead of a list of predictions. "
"When batching is enabled, 'predict' must return a list to handle multiple inputs correctly. "
"Please update the 'predict' method to return a list of predictions to avoid unexpected behavior.",
UserWarning,
)
return list(output)

def _unbatch_stream(self, output_stream):
tests/test_batch.py (16 additions, 0 deletions)

@@ -148,6 +148,22 @@ def test_max_batch_size_warning():
LitServer(SimpleTorchAPI(), accelerator="cpu", devices=1, timeout=2)


def test_batch_predict_string_warning():
api = ls.test_examples.SimpleBatchedAPI()
api._sanitize(2, None)
api.predict = MagicMock(return_value="This is a string")

mock_input = torch.tensor([[1.0], [2.0]])

with pytest.warns(
UserWarning,
match="When batching is enabled, 'predict' must return a list to handle multiple inputs correctly.",
):
# Simulate the behavior in run_batched_loop
y = api.predict(mock_input)
api.unbatch(y)


class FakeResponseQueue:
def put(self, *args):
raise Exception("Exit loop")
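The regression test above follows a common pattern: patch `predict` with a `MagicMock` that returns the bad type, then assert the warning fires with `pytest.warns(..., match=...)`. A self-contained sketch of that pattern, with a hypothetical `BatchedAPI` class standing in for the LitServe API so it runs without the library:

```python
import warnings
from unittest.mock import MagicMock

import pytest


class BatchedAPI:
    """Illustrative stand-in for a LitServe-style API with batching enabled."""

    def predict(self, batch):
        # A well-behaved batched predict: one output per input.
        return [x * 2 for x in batch]

    def unbatch(self, output):
        # Same guard as the PR: warn if predict returned a string.
        if isinstance(output, str):
            warnings.warn(
                "When batching is enabled, 'predict' must return a list "
                "to handle multiple inputs correctly.",
                UserWarning,
            )
        return list(output)


def test_string_predict_warns():
    api = BatchedAPI()
    # Force the buggy behavior the PR guards against:
    api.predict = MagicMock(return_value="This is a string")
    with pytest.warns(UserWarning, match="must return a list"):
        api.unbatch(api.predict([1.0, 2.0]))
```

Note that `match` is a regex searched against the warning message, which is why the PR needed follow-up commits ("Update test to match new warning string") whenever the warning copy changed.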