Skip to content

Commit

Permalink
add loading chains from hub (#757)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 authored Jan 27, 2023
1 parent 1b89a43 commit f273c50
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion langchain/chains/loading.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
"""Functionality for loading chains."""
import json
import os
import tempfile
from pathlib import Path
from typing import Union

import requests
import yaml

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config

URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"


def _load_llm_chain(config: dict) -> LLMChain:
"""Load LLM chain from config dict."""
Expand Down Expand Up @@ -48,7 +53,16 @@ def load_chain_from_config(config: dict) -> Chain:
return chain_loader(config)


def load_chain(file: Union[str, Path]) -> Chain:
def load_chain(path: Union[str, Path]) -> Chain:
"""Unified method for loading a chain from LangChainHub or local fs."""
if isinstance(path, str) and path.startswith("lc://chains"):
path = os.path.relpath(path, "lc://chains/")
return _load_from_hub(path)
else:
return _load_chain_from_file(path)


def _load_chain_from_file(file: Union[str, Path]) -> Chain:
"""Load chain from file."""
# Convert file to Path object.
if isinstance(file, str):
Expand All @@ -66,3 +80,19 @@ def load_chain(file: Union[str, Path]) -> Chain:
raise ValueError("File type must be json or yaml")
# Load the chain from the config now.
return load_chain_from_config(config)


def _load_from_hub(path: str) -> Chain:
"""Load chain from hub."""
suffix = path.split(".")[-1]
if suffix not in {"json", "yaml"}:
raise ValueError("Unsupported file type.")
full_url = URL_BASE + path
r = requests.get(full_url)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = tmpdirname + "/chain." + suffix
with open(file, "wb") as f:
f.write(r.content)
return _load_chain_from_file(file)

0 comments on commit f273c50

Please sign in to comment.