diff --git a/src/litserve/api.py b/src/litserve/api.py
index 18c9bb7d..6c6699c1 100644
--- a/src/litserve/api.py
+++ b/src/litserve/api.py
@@ -63,6 +63,13 @@ def predict(self, x, **kwargs):
         pass

     def _unbatch_no_stream(self, output):
+        if isinstance(output, str):
+            warnings.warn(
+                "The 'predict' method returned a string instead of a list of predictions. "
+                "When batching is enabled, 'predict' must return a list to handle multiple inputs correctly. "
+                "Please update the 'predict' method to return a list of predictions to avoid unexpected behavior.",
+                UserWarning,
+            )
         return list(output)

     def _unbatch_stream(self, output_stream):
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 5a1d1a1f..8b0865ff 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -148,6 +148,22 @@ def test_max_batch_size_warning():
         LitServer(SimpleTorchAPI(), accelerator="cpu", devices=1, timeout=2)


+def test_batch_predict_string_warning():
+    api = ls.test_examples.SimpleBatchedAPI()
+    api._sanitize(2, None)
+    api.predict = MagicMock(return_value="This is a string")
+
+    mock_input = torch.tensor([[1.0], [2.0]])
+
+    with pytest.warns(
+        UserWarning,
+        match="When batching is enabled, 'predict' must return a list to handle multiple inputs correctly.",
+    ):
+        # Simulate the behavior in run_batched_loop
+        y = api.predict(mock_input)
+        api.unbatch(y)
+
+
 class FakeResponseQueue:
     def put(self, *args):
         raise Exception("Exit loop")
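
For context on why the patch special-cases strings: `_unbatch_no_stream` splits a batched prediction back into per-request results with `list(output)`, and calling `list()` on a string splits it into characters. The sketch below is a minimal, standalone illustration of that failure mode in plain Python; it does not use LitServe itself, and the example values are made up.

```python
# Sketch of the failure mode the new warning guards against.
# list(output) is how the unbatch step splits a batched prediction
# back into one result per request.

batched_prediction_ok = ["cat", "dog"]       # one prediction per request in the batch
batched_prediction_bad = "This is a string"  # a single string returned by mistake

print(list(batched_prediction_ok))   # ['cat', 'dog'] -> each request gets its own prediction
print(list(batched_prediction_bad))  # ['T', 'h', 'i', 's', ...] -> each request gets one character
```

With the patch applied, the second case still proceeds, but a `UserWarning` is emitted so the misuse is visible instead of silently returning per-character results.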