Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 29, 2023
1 parent b67db8d commit 0318cdd
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 5 deletions.
8 changes: 4 additions & 4 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,7 +2100,7 @@ def InputType(self) -> Any:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
return getattr(first_param.annotation, "__args__", (Any,))[0]
else:
return Any
except ValueError:
Expand All @@ -2112,7 +2112,7 @@ def OutputType(self) -> Type[Output]:
try:
sig = inspect.signature(func)
return (
sig.return_annotation
getattr(sig.return_annotation, "__args__", (Any,))[0]
if sig.return_annotation != inspect.Signature.empty
else Any
)
Expand Down Expand Up @@ -2162,7 +2162,7 @@ def invoke(
final += output
return final

async def atransform(
def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
Expand All @@ -2175,7 +2175,7 @@ async def atransform(
input, self._atransform, config, **kwargs
)

async def astream(
def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
Expand Down
91 changes: 90 additions & 1 deletion libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import sys
from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID
from langchain.schema.runnable.base import RunnableGenerator

import pytest
from freezegun import freeze_time
Expand Down Expand Up @@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None:
"title": "PromptInput",
"type": "object",
}


@pytest.mark.asyncio
async def test_runnable_gen() -> None:
"""Test that a generator can be used as a runnable."""

def gen(input: Iterator[Any]) -> Iterator[int]:
yield 1
yield 2
yield 3

runnable = RunnableGenerator(gen)

assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"}
assert runnable.output_schema.schema() == {
"title": "RunnableGeneratorOutput",
"type": "integer",
}

assert runnable.invoke(None) == 6
assert list(runnable.stream(None)) == [1, 2, 3]
assert runnable.batch([None, None]) == [6, 6]

async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
yield 1
yield 2
yield 3

arunnable = RunnableGenerator(agen)

assert await arunnable.ainvoke(None) == 6
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
assert await arunnable.abatch([None, None]) == [6, 6]


@pytest.mark.asyncio
async def test_runnable_gen_transform() -> None:
"""Test that a generator can be used as a runnable."""

def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]:
for i in range(next(length_iter)):
yield i

async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]:
async for length in length_iter:
for i in range(length):
yield i

def plus_one(input: Iterator[int]) -> Iterator[int]:
for i in input:
yield i + 1

async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
async for i in input:
yield i + 1

chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one

assert chain.input_schema.schema() == {
"title": "RunnableGeneratorInput",
"type": "integer",
}
assert chain.output_schema.schema() == {
"title": "RunnableGeneratorOutput",
"type": "integer",
}
assert achain.input_schema.schema() == {
"title": "RunnableGeneratorInput",
"type": "integer",
}
assert achain.output_schema.schema() == {
"title": "RunnableGeneratorOutput",
"type": "integer",
}

assert list(chain.stream(3)) == [1, 2, 3]
assert [p async for p in achain.astream(4)] == [1, 2, 3, 4]

0 comments on commit 0318cdd

Please sign in to comment.