From 136f75949279901ec9b31a06ce1298cfed5a6f54 Mon Sep 17 00:00:00 2001
From: mrbean <43734688+sam-h-bean@users.noreply.github.com>
Date: Wed, 21 Dec 2022 23:39:07 -0500
Subject: [PATCH] Mrbean/support timeout (#398)

Add support for passing in a request timeout to the API
---
 langchain/llms/huggingface_pipeline.py | 2 +-
 langchain/llms/openai.py               | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py
index eab476ce9f7f0..40198a9b8d6a9 100644
--- a/langchain/llms/huggingface_pipeline.py
+++ b/langchain/llms/huggingface_pipeline.py
@@ -37,7 +37,7 @@ class HuggingFacePipeline(LLM, BaseModel):
             pipe = pipeline(
                 "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
             )
-            hf = HuggingFacePipeline(pipeline=pipe
+            hf = HuggingFacePipeline(pipeline=pipe)
     """

     pipeline: Any  #: :meta private:
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 0eb170242c0b2..42ca39b5c1b06 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -1,6 +1,6 @@
 """Wrapper around OpenAI APIs."""
 import sys
-from typing import Any, Dict, Generator, List, Mapping, Optional
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union

 from pydantic import BaseModel, Extra, Field, root_validator

@@ -49,6 +49,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
     openai_api_key: Optional[str] = None
     batch_size: int = 20
     """Batch size to use when passing multiple documents to generate."""
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""

     class Config:
         """Configuration for this pydantic object."""
@@ -98,6 +100,7 @@ def _default_params(self) -> Dict[str, Any]:
             "presence_penalty": self.presence_penalty,
             "n": self.n,
             "best_of": self.best_of,
+            "request_timeout": self.request_timeout,
         }
         return {**normal_params, **self.model_kwargs}
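
A minimal usage sketch of the new field, assuming the `OpenAI` wrapper in `langchain.llms` inherits these `BaseOpenAI` attributes; the `model_name` value, the prompt, and the callable `llm(...)` interface are illustrative assumptions, not part of this patch:

```python
from langchain.llms import OpenAI

# Single float: one overall per-request timeout in seconds
# (hypothetical model_name shown for illustration).
llm = OpenAI(model_name="text-davinci-003", request_timeout=30.0)

# Tuple form: separate (connect, read) timeouts, matching the
# Optional[Union[float, Tuple[float, float]]] annotation added above.
llm_split = OpenAI(model_name="text-davinci-003", request_timeout=(5.0, 60.0))

# The value is merged into _default_params and forwarded with each
# completion request; leaving it unset falls back to the default timeout.
print(llm("Say hello in one word."))
```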