Skip to content

Commit

Permalink
core: add additional import mappings to loads (#26406)
Browse files Browse the repository at this point in the history
Support additional import mappings. This allows users to override the
default mappings or add new ones when calling the `load` and `loads` functions.

- [x] **Add tests and docs**: If you're adding a new integration,
please include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/
  • Loading branch information
langchain-infra authored Sep 13, 2024
1 parent 1d98937 commit 8a02fd9
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 9 deletions.
49 changes: 41 additions & 8 deletions libs/core/langchain_core/load/load.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from langchain_core._api import beta
from langchain_core.load.mapping import (
Expand Down Expand Up @@ -36,6 +36,9 @@ def __init__(
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[
Dict[Tuple[str, ...], Tuple[str, ...]]
] = None,
) -> None:
"""Initialize the reviver.
Expand All @@ -47,15 +50,27 @@ def __init__(
to allow to be deserialized. Defaults to None.
secrets_from_env: Whether to load secrets from the environment.
Defaults to True.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
Defaults to None.
"""
self.secrets_from_env = secrets_from_env
self.secrets_map = secrets_map or dict()
# By default only support langchain, but user can pass in additional namespaces
# By default, only support langchain, but user can pass in additional namespaces
self.valid_namespaces = (
[*DEFAULT_NAMESPACES, *valid_namespaces]
if valid_namespaces
else DEFAULT_NAMESPACES
)
self.additional_import_mappings = additional_import_mappings or dict()
self.import_mappings = (
{
**ALL_SERIALIZABLE_MAPPINGS,
**self.additional_import_mappings,
}
if self.additional_import_mappings
else ALL_SERIALIZABLE_MAPPINGS
)

def __call__(self, value: Dict[str, Any]) -> Any:
if (
Expand Down Expand Up @@ -96,16 +111,16 @@ def __call__(self, value: Dict[str, Any]) -> Any:
raise ValueError(f"Invalid namespace: {value}")

# If namespace is in known namespaces, try to use mapping
key = tuple(namespace + [name])
if namespace[0] in DEFAULT_NAMESPACES:
# Get the importable path
key = tuple(namespace + [name])
if key not in ALL_SERIALIZABLE_MAPPINGS:
if key not in self.import_mappings:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{key}"
)
import_path = ALL_SERIALIZABLE_MAPPINGS[key]
import_path = self.import_mappings[key]
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
Expand All @@ -114,7 +129,12 @@ def __call__(self, value: Dict[str, Any]) -> Any:
cls = getattr(mod, import_obj)
# Otherwise, load by path
else:
mod = importlib.import_module(".".join(namespace))
if key in self.additional_import_mappings:
import_path = self.import_mappings[key]
mod = importlib.import_module(".".join(import_path[:-1]))
name = import_path[-1]
else:
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)

# The class must be a subclass of Serializable.
Expand All @@ -136,6 +156,7 @@ def loads(
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
) -> Any:
"""Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`.
Expand All @@ -149,12 +170,18 @@ def loads(
to allow to be deserialized. Defaults to None.
secrets_from_env: Whether to load secrets from the environment.
Defaults to True.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
Defaults to None.
Returns:
Revived LangChain objects.
"""
return json.loads(
text, object_hook=Reviver(secrets_map, valid_namespaces, secrets_from_env)
text,
object_hook=Reviver(
secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings
),
)


Expand All @@ -165,6 +192,7 @@ def load(
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
) -> Any:
"""Revive a LangChain class from a JSON object. Use this if you already
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
Expand All @@ -178,11 +206,16 @@ def load(
to allow to be deserialized. Defaults to None.
secrets_from_env: Whether to load secrets from the environment.
Defaults to True.
additional_import_mappings: A dictionary of additional namespace mappings
You can use this to override default mappings or add new mappings.
Defaults to None.
Returns:
Revived LangChain objects.
"""
reviver = Reviver(secrets_map, valid_namespaces, secrets_from_env)
reviver = Reviver(
secrets_map, valid_namespaces, secrets_from_env, additional_import_mappings
)

def _load(obj: Any) -> Any:
if isinstance(obj, dict):
Expand Down
60 changes: 59 additions & 1 deletion libs/core/tests/unit_tests/load/test_serializable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict

from langchain_core.load import Serializable, dumpd
from langchain_core.load import Serializable, dumpd, load
from langchain_core.load.serializable import _is_field_useful
from langchain_core.pydantic_v1 import Field

Expand Down Expand Up @@ -107,3 +107,61 @@ class Config:
foo = Foo(x=default_x, y=default_y, z=ArrayObj())
assert not _is_field_useful(foo, "x", foo.x)
assert not _is_field_useful(foo, "y", foo.y)


class Foo(Serializable):
    """Small serializable model with two fields, used by the load tests."""

    bar: int
    baz: str

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report this class as serializable."""
        return True


def test_simple_deserialization() -> None:
    """Round-trip a ``Foo`` through ``dumpd`` and ``load`` and compare."""
    expected_id = ["tests", "unit_tests", "load", "test_serializable", "Foo"]

    original = Foo(bar=1, baz="hello")
    assert original.lc_id() == expected_id

    payload = dumpd(original)
    assert payload == {
        "lc": 1,
        "type": "constructor",
        "id": expected_id,
        "kwargs": {"bar": 1, "baz": "hello"},
    }

    # The "tests" namespace is passed explicitly so the loader accepts it.
    restored = load(payload, valid_namespaces=["tests"])
    assert restored == original


class Foo2(Serializable):
    """Second serializable model with the same fields as ``Foo``."""

    bar: int
    baz: str

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report this class as serializable."""
        return True


def test_simple_deserialization_with_additional_imports() -> None:
    """Load a serialized ``Foo`` as ``Foo2`` via ``additional_import_mappings``."""
    module_path = ("tests", "unit_tests", "load", "test_serializable")
    foo_id = [*module_path, "Foo"]

    original = Foo(bar=1, baz="hello")
    assert original.lc_id() == foo_id

    payload = dumpd(original)
    assert payload == {
        "lc": 1,
        "type": "constructor",
        "id": foo_id,
        "kwargs": {"bar": 1, "baz": "hello"},
    }

    # Map the serialized id of Foo onto the import path of Foo2.
    redirect = {(*module_path, "Foo"): (*module_path, "Foo2")}
    restored = load(
        payload,
        valid_namespaces=["tests"],
        additional_import_mappings=redirect,
    )
    assert isinstance(restored, Foo2)

0 comments on commit 8a02fd9

Please sign in to comment.