diff --git a/replicate/run.py b/replicate/run.py index 3b6bddb..19db492 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -15,7 +15,6 @@ from replicate.exceptions import ModelError from replicate.helpers import transform_output from replicate.model import Model -from replicate.prediction import Prediction from replicate.schema import make_schema_backwards_compatible from replicate.version import Version, Versions @@ -59,15 +58,36 @@ def run( if not version and (owner and name and version_id): version = Versions(client, model=(owner, name)).get(version_id) - if version and (iterator := _make_output_iterator(version, prediction)): - return iterator + # Currently the "Prefer: wait" interface will return a prediction with a status + # of "processing" rather than a terminal state because it returns before the + # prediction has been fully processed. If request exceeds the wait time, even if + # it is actually processing, the prediction will be in a "starting" state. + # + # We should fix this in the blocking API itself. Predictions that are done should + # be in a terminal state and predictions that are processing should be in state + # "processing". + in_terminal_state = is_blocking and prediction.status != "starting" + if not in_terminal_state: + # Return a "polling" iterator if the model has an output iterator array type. + if version and _has_output_iterator_array_type(version): + return ( + transform_output(chunk, client) + for chunk in prediction.output_iterator() + ) - if not (is_blocking and prediction.status != "starting"): prediction.wait() if prediction.status == "failed": raise ModelError(prediction) + # Return an iterator for the completed prediction when needed. + if ( + version + and _has_output_iterator_array_type(version) + and prediction.output is not None + ): + return (transform_output(chunk, client) for chunk in prediction.output) + if use_file_output: return transform_output(prediction.output, client) @@ -108,15 +128,39 @@ async def async_run( if not version and (owner and name and version_id): version = await Versions(client, model=(owner, name)).async_get(version_id) - if version and (iterator := _make_async_output_iterator(version, prediction)): - return iterator + # Currently the "Prefer: wait" interface will return a prediction with a status + # of "processing" rather than a terminal state because it returns before the + # prediction has been fully processed. If request exceeds the wait time, even if + # it is actually processing, the prediction will be in a "starting" state. + # + # We should fix this in the blocking API itself. Predictions that are done should + # be in a terminal state and predictions that are processing should be in state + # "processing". + in_terminal_state = is_blocking and prediction.status != "starting" + if not in_terminal_state: + # Return a "polling" iterator if the model has an output iterator array type. + if version and _has_output_iterator_array_type(version): + return ( + transform_output(chunk, client) + async for chunk in prediction.async_output_iterator() + ) - if not (is_blocking and prediction.status != "starting"): await prediction.async_wait() if prediction.status == "failed": raise ModelError(prediction) + # Return an iterator for completed output if the model has an output iterator array type. + if ( + version + and _has_output_iterator_array_type(version) + and prediction.output is not None + ): + return ( + transform_output(chunk, client) + async for chunk in _make_async_iterator(prediction.output) + ) + if use_file_output: return transform_output(prediction.output, client) @@ -133,22 +177,9 @@ def _has_output_iterator_array_type(version: Version) -> bool: ) -def _make_output_iterator( - version: Version, prediction: Prediction -) -> Optional[Iterator[Any]]: - if _has_output_iterator_array_type(version): - return prediction.output_iterator() - - return None - - -def _make_async_output_iterator( - version: Version, prediction: Prediction -) -> Optional[AsyncIterator[Any]]: - if _has_output_iterator_array_type(version): - return prediction.async_output_iterator() - - return None +async def _make_async_iterator(list: list) -> AsyncIterator: + for item in list: + yield item __all__: List = [] diff --git a/tests/test_run.py b/tests/test_run.py index 8eac091..beb7f6e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,6 @@ import asyncio import sys -from typing import cast +from typing import AsyncIterator, Iterator, Optional, cast import httpx import pytest @@ -48,6 +48,400 @@ async def test_run(async_flag, record_mode): assert output[0].url.startswith("https://") +@pytest.mark.asyncio +async def test_run_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[str], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + wait=False, + ), + ) + + output = [chunk for chunk in stream] + assert output == ["Hello, ", "world!"] + + +@pytest.mark.asyncio +async def test_async_run_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[FileOutput], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + wait=False, + ), + ) + + output = [chunk async for chunk in stream] + assert output == ["Hello, ", "world!"] + + +@pytest.mark.asyncio +async def test_run_blocking_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + "world!", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[str], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + assert list(stream) == ["Hello, ", "world!"] + + +@pytest.mark.asyncio +async def test_run_blocking_timeout_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + # Initial request times out and returns "starting" state. + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "starting", + ), + ) + ) + # Client should start polling for the prediction. + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[str], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + assert list(stream) == ["Hello, ", "world!"] + + +@pytest.mark.asyncio +async def test_async_run_blocking_timeout_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + # Initial request times out and returns "starting" state. + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "starting", + ), + ) + ) + # Client should start polling for the prediction. + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[str], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output = [chunk async for chunk in stream] + assert output == ["Hello, ", "world!"] + + +@pytest.mark.asyncio +async def test_async_run_blocking_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + "world!", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[FileOutput], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output = [chunk async for chunk in stream] + assert output == ["Hello, ", "world!"] + + @pytest.mark.vcr("run__concurrently.yaml") @pytest.mark.asyncio @pytest.mark.skipif( @@ -104,35 +498,17 @@ async def test_run_with_invalid_token(): @pytest.mark.asyncio async def test_run_version_with_invalid_cog_version(mock_replicate_api_token): - def prediction_with_status(status: str) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": "Hello, world!" if status == "succeeded" else None, - "error": None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status("succeeded"), + json=_prediction_with_status("succeeded", "Hello, world!"), ) ) router.route( @@ -141,37 +517,7 @@ def prediction_with_status(status: str) -> dict: ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2022-03-16T00:35:56.210272Z", - "cog_version": "dev", - "openapi_schema": { - "openapi": "3.0.2", - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": {}, - "components": { - "schemas": { - "Input": { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "The text input", - }, - }, - }, - "Output": { - "type": "string", - "title": "Output", - }, - } - }, - }, - }, + json=_version_with_schema(), ) ) router.route(host="api.replicate.com").pass_through() @@ -193,35 +539,17 @@ def prediction_with_status(status: str) -> dict: @pytest.mark.asyncio async def test_run_with_model_error(mock_replicate_api_token): - def prediction_with_status(status: str) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": None, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status("failed"), + json=_prediction_with_status("failed"), ) ) router.route( @@ -230,14 +558,7 @@ def prediction_with_status(status: str) -> dict: ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(host="api.replicate.com").pass_through() @@ -262,37 +583,17 @@ def prediction_with_status(status: str) -> dict: @pytest.mark.asyncio async def test_run_with_file_output(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "https://api.replicate.com/v1/assets/output.txt" ), ) @@ -303,14 +604,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(method="GET", path="/assets/output.txt").mock( @@ -347,31 +641,11 @@ def prediction_with_status( @pytest.mark.asyncio async def test_run_with_file_output_blocking(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") predictions_create_route = router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status( + json=_prediction_with_status( "processing", "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==" ), ) @@ -379,7 +653,7 @@ def prediction_with_status( predictions_get_route = router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "https://api.replicate.com/v1/assets/output.txt" ), ) @@ -387,26 +661,14 @@ def prediction_with_status( router.route( method="GET", path="/models/test/example/versions/v1", - ).mock( - return_value=httpx.Response( - 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, - ) - ) + ).mock(return_value=httpx.Response(201, json=_version_with_schema())) client = Client( api_token="test-token", transport=httpx.MockTransport(router.handler) ) client.poll_interval = 0.001 output = cast( - list[FileOutput], + FileOutput, client.run( "test/example:v1", input={ @@ -434,37 +696,17 @@ def prediction_with_status( @pytest.mark.asyncio async def test_run_with_file_output_array(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", [ "https://api.replicate.com/v1/assets/hello.txt", @@ -479,14 +721,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(method="GET", path="/assets/hello.txt").mock( @@ -521,38 +756,103 @@ def prediction_with_status( @pytest.mark.asyncio -async def test_run_with_file_output_data_uri(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", +async def test_run_with_file_output_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "https://api.replicate.com/v1/assets/hello.txt", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "https://api.replicate.com/v1/assets/hello.txt", + "https://api.replicate.com/v1/assets/world.txt", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + "format": "uri", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + router.route(method="GET", path="/assets/hello.txt").mock( + return_value=httpx.Response(200, content=b"Hello,") + ) + router.route(method="GET", path="/assets/world.txt").mock( + return_value=httpx.Response(200, content=b" world!") + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } + use_file_output=True, + wait=False, + ), + ) + + expected = [ + {"url": "https://api.replicate.com/v1/assets/hello.txt", "content": b"Hello,"}, + {"url": "https://api.replicate.com/v1/assets/world.txt", "content": b" world!"}, + ] + + for output, expect in zip(stream, expected): + assert output.url == expect["url"] + assert output.read() == expect["content"] + +@pytest.mark.asyncio +async def test_run_with_file_output_data_uri(mock_replicate_api_token): router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==", ), @@ -564,14 +864,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) @@ -600,3 +893,57 @@ def prediction_with_status( assert await output.aread() == b"Hello, world!" async for chunk in output: assert chunk == b"Hello, world!" + + +def _prediction_with_status(status: str, output: str | list[str] | None = None) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": output, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + +def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None): + return { + "id": id, + "created_at": "2022-03-16T00:35:56.210272Z", + "cog_version": "dev", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": {}, + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "The text input", + }, + }, + }, + "Output": output_schema + or { + "type": "string", + "title": "Output", + }, + } + }, + }, + }