diff --git a/libs/langchain/langchain/chains/rl_chain/__init__.py b/libs/langchain/langchain/chains/rl_chain/__init__.py index e71de1da6ccf8..6d5cfc3e29c78 100644 --- a/libs/langchain/langchain/chains/rl_chain/__init__.py +++ b/libs/langchain/langchain/chains/rl_chain/__init__.py @@ -13,7 +13,7 @@ from langchain.chains.rl_chain.pick_best_chain import PickBest -def configure_logger(): +def configure_logger() -> None: logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) ch = logging.StreamHandler() diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index 28baf898d2cda..721b7d35de932 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -3,7 +3,18 @@ import logging import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -26,47 +37,47 @@ class _BasedOn: - def __init__(self, value): + def __init__(self, value: Any): self.value = value - def __str__(self): + def __str__(self) -> str: return str(self.value) __repr__ = __str__ -def BasedOn(anything): +def BasedOn(anything: Any) -> _BasedOn: return _BasedOn(anything) class _ToSelectFrom: - def __init__(self, value): + def __init__(self, value: Any): self.value = value - def __str__(self): + def __str__(self) -> str: return str(self.value) __repr__ = __str__ -def ToSelectFrom(anything): +def ToSelectFrom(anything: Any) -> _ToSelectFrom: if not isinstance(anything, list): raise ValueError("ToSelectFrom must be a list to select from") return _ToSelectFrom(anything) class _Embed: - def __init__(self, value, keep=False): + def __init__(self, value: Any, keep: bool = False): self.value = value self.keep = keep - def __str__(self): + def __str__(self) -> str: return str(self.value) __repr__ = __str__ -def Embed(anything, keep=False): +def Embed(anything: Any, keep: bool = False) -> Any: if isinstance(anything, _ToSelectFrom): return ToSelectFrom(Embed(anything.value, keep=keep)) elif isinstance(anything, _BasedOn): @@ -80,7 +91,7 @@ def Embed(anything, keep=False): return _Embed(anything, keep=keep) -def EmbedAndKeep(anything): +def EmbedAndKeep(anything: Any) -> Any: return Embed(anything, keep=True) @@ -91,7 +102,7 @@ def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Examp return [parser.parse_line(line) for line in input_str.split("\n")] -def get_based_on_and_to_select_from(inputs: Dict[str, Any]): +def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]: to_select_from = { k: inputs[k].value for k in inputs.keys() @@ -113,7 +124,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]): return based_on, to_select_from -def prepare_inputs_for_autoembed(inputs: Dict[str, Any]): +def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]: """ go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status @@ -134,29 +145,38 @@ class Selected(ABC): pass -class Event(ABC): +TSelected = TypeVar("TSelected", bound=Selected) + + +class Event(Generic[TSelected], ABC): inputs: Dict[str, 
Any] - selected: Optional[Selected] + selected: Optional[TSelected] - def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None): + def __init__(self, inputs: Dict[str, Any], selected: Optional[TSelected] = None): self.inputs = inputs self.selected = selected +TEvent = TypeVar("TEvent", bound=Event) + + class Policy(ABC): - @abstractmethod - def predict(self, event: Event) -> Any: + def __init__(self, **kwargs: Any): pass @abstractmethod - def learn(self, event: Event): - pass + def predict(self, event: TEvent) -> Any: + ... @abstractmethod - def log(self, event: Event): - pass + def learn(self, event: TEvent) -> None: + ... + + @abstractmethod + def log(self, event: TEvent) -> None: + ... - def save(self): + def save(self) -> None: pass @@ -164,11 +184,11 @@ class VwPolicy(Policy): def __init__( self, model_repo: ModelRepository, - vw_cmd: Sequence[str], + vw_cmd: List[str], feature_embedder: Embedder, vw_logger: VwLogger, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) self.model_repo = model_repo @@ -176,7 +196,7 @@ def __init__( self.feature_embedder = feature_embedder self.vw_logger = vw_logger - def predict(self, event: Event) -> Any: + def predict(self, event: TEvent) -> Any: import vowpal_wabbit_next as vw text_parser = vw.TextFormatParser(self.workspace) @@ -184,7 +204,7 @@ def predict(self, event: Event) -> Any: parse_lines(text_parser, self.feature_embedder.format(event)) ) - def learn(self, event: Event): + def learn(self, event: TEvent) -> None: import vowpal_wabbit_next as vw vw_ex = self.feature_embedder.format(event) @@ -192,19 +212,19 @@ def learn(self, event: Event): multi_ex = parse_lines(text_parser, vw_ex) self.workspace.learn_one(multi_ex) - def log(self, event: Event): + def log(self, event: TEvent) -> None: if self.vw_logger.logging_enabled(): vw_ex = self.feature_embedder.format(event) self.vw_logger.log(vw_ex) - def save(self): - self.model_repo.save() + def save(self) -> None: + self.model_repo.save(self.workspace) -class Embedder(ABC): +class Embedder(Generic[TEvent], ABC): @abstractmethod - def format(self, event: Event) -> str: - pass + def format(self, event: TEvent) -> str: + ... class SelectionScorer(ABC, BaseModel): @@ -212,11 +232,11 @@ class SelectionScorer(ABC, BaseModel): @abstractmethod def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: - pass + ... class AutoSelectionScorer(SelectionScorer, BaseModel): - llm_chain: Union[LLMChain, None] = None + llm_chain: LLMChain prompt: Union[BasePromptTemplate, None] = None scoring_criteria_template_str: Optional[str] = None @@ -243,7 +263,7 @@ def get_default_prompt() -> ChatPromptTemplate: return chat_prompt @root_validator(pre=True) - def set_prompt_and_llm_chain(cls, values): + def set_prompt_and_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: llm = values.get("llm") prompt = values.get("prompt") scoring_criteria_template_str = values.get("scoring_criteria_template_str") @@ -275,7 +295,7 @@ def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: ) -class RLChain(Chain): +class RLChain(Chain, Generic[TEvent]): """ The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning. @@ -292,7 +312,7 @@ class RLChain(Chain): - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory. - reset_model (bool): If set to True, the model starts training from scratch. Default is False. 
- vw_cmd (List[str], optional): Command line arguments for the VW model. - - policy (VwPolicy): Policy used by the chain. + - policy (Type[VwPolicy]): Policy used by the chain. - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs. - metrics_step (int): Step for the metrics tracker. Default is -1. @@ -300,12 +320,24 @@ class RLChain(Chain): The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called. """ # noqa: E501 + class _NoOpPolicy(Policy): + """Placeholder policy that does nothing""" + + def predict(self, event: TEvent) -> Any: + return None + + def learn(self, event: TEvent) -> None: + pass + + def log(self, event: TEvent) -> None: + pass + llm_chain: Chain output_key: str = "result" #: :meta private: prompt: BasePromptTemplate selection_scorer: Union[SelectionScorer, None] - policy: Optional[Policy] + active_policy: Policy = _NoOpPolicy() auto_embed: bool = True selected_input_key = "rl_chain_selected" selected_based_on_input_key = "rl_chain_selected_based_on" @@ -314,14 +346,14 @@ class RLChain(Chain): def __init__( self, feature_embedder: Embedder, - model_save_dir="./", - reset_model=False, - vw_cmd=None, - policy=VwPolicy, + model_save_dir: str = "./", + reset_model: bool = False, + vw_cmd: Optional[List[str]] = None, + policy: Type[Policy] = VwPolicy, vw_logs: Optional[Union[str, os.PathLike]] = None, - metrics_step=-1, - *args, - **kwargs, + metrics_step: int = -1, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) if self.selection_scorer is None: @@ -330,14 +362,17 @@ def __init__( reinforcement learning will be done in the RL chain \ unless update_with_delayed_score is called." ) - self.policy = policy( - model_repo=ModelRepository( - model_save_dir, with_history=True, reset=reset_model - ), - vw_cmd=vw_cmd or [], - feature_embedder=feature_embedder, - vw_logger=VwLogger(vw_logs), - ) + + if isinstance(self.active_policy, RLChain._NoOpPolicy): + self.active_policy = policy( + model_repo=ModelRepository( + model_save_dir, with_history=True, reset=reset_model + ), + vw_cmd=vw_cmd or [], + feature_embedder=feature_embedder, + vw_logger=VwLogger(vw_logs), + ) + self.metrics = MetricsTracker(step=metrics_step) class Config: @@ -374,29 +409,29 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None: ) @abstractmethod - def _call_before_predict(self, inputs: Dict[str, Any]) -> Event: - pass + def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent: + ... @abstractmethod def _call_after_predict_before_llm( - self, inputs: Dict[str, Any], event: Event, prediction: Any - ) -> Tuple[Dict[str, Any], Event]: - pass + self, inputs: Dict[str, Any], event: TEvent, prediction: Any + ) -> Tuple[Dict[str, Any], TEvent]: + ... @abstractmethod def _call_after_llm_before_scoring( - self, llm_response: str, event: Event - ) -> Tuple[Dict[str, Any], Event]: - pass + self, llm_response: str, event: TEvent + ) -> Tuple[Dict[str, Any], TEvent]: + ... @abstractmethod def _call_after_scoring_before_learning( - self, event: Event, score: Optional[float] - ) -> Event: - pass + self, event: TEvent, score: Optional[float] + ) -> TEvent: + ... def update_with_delayed_score( - self, score: float, event: Event, force_score=False + self, score: float, event: TEvent, force_score: bool = False ) -> None: """ Updates the learned policy with the score provided. 
@@ -407,10 +442,11 @@ def update_with_delayed_score( "The selection scorer is set, and force_score was not set to True. \ Please set force_score=True to use this function." ) - self.metrics.on_feedback(score) + if self.metrics: + self.metrics.on_feedback(score) self._call_after_scoring_before_learning(event=event, score=score) - self.policy.learn(event=event) - self.policy.log(event=event) + self.active_policy.learn(event=event) + self.active_policy.log(event=event) def set_auto_embed(self, auto_embed: bool) -> None: """ @@ -422,15 +458,16 @@ def _call( self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: + ) -> Dict[str, Any]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() if self.auto_embed: inputs = prepare_inputs_for_autoembed(inputs=inputs) - event = self._call_before_predict(inputs=inputs) - prediction = self.policy.predict(event=event) - self.metrics.on_decision() + event: TEvent = self._call_before_predict(inputs=inputs) + prediction = self.active_policy.predict(event=event) + if self.metrics: + self.metrics.on_decision() next_chain_inputs, event = self._call_after_predict_before_llm( inputs=inputs, event=event, prediction=prediction @@ -462,10 +499,11 @@ def _call( f"The selection scorer was not able to score, \ and the chain was not able to adjust to this response, error: {e}" ) - self.metrics.on_feedback(score) + if self.metrics: + self.metrics.on_feedback(score) event = self._call_after_scoring_before_learning(score=score, event=event) - self.policy.learn(event=event) - self.policy.log(event=event) + self.active_policy.learn(event=event) + self.active_policy.log(event=event) return {self.output_key: {"response": output, "selection_metadata": event}} @@ -473,7 +511,7 @@ def save_progress(self) -> None: """ This function should be called to save the state of the learned policy model. 
""" - self.policy.save() + self.active_policy.save() @property def _chain_type(self) -> str: @@ -489,7 +527,7 @@ def is_stringtype_instance(item: Any) -> bool: def embed_string_type( item: Union[str, _Embed], model: Any, namespace: Optional[str] = None -) -> Dict[str, str]: +) -> Dict[str, Union[str, List[str]]]: """Helper function to embed a string or an _Embed object.""" join_char = "" keep_str = "" @@ -513,9 +551,9 @@ def embed_string_type( return {namespace: keep_str + join_char.join(map(str, encoded))} -def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]: +def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]: """Helper function to embed a dictionary item.""" - inner_dict = {} + inner_dict: Dict[str, Any] = {} for ns, embed_item in item.items(): if isinstance(embed_item, list): inner_dict[ns] = [] @@ -530,7 +568,7 @@ def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]: def embed_list_type( item: list, model: Any, namespace: Optional[str] = None ) -> List[Dict[str, Union[str, List[str]]]]: - ret_list = [] + ret_list: List[Dict[str, Union[str, List[str]]]] = [] for embed_item in item: if isinstance(embed_item, dict): ret_list.append(embed_dict_type(embed_item, model)) @@ -540,9 +578,7 @@ def embed_list_type( def embed( - to_embed: Union[ - Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict] - ], + to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]], model: Any, namespace: Optional[str] = None, ) -> List[Dict[str, Union[str, List[str]]]]: diff --git a/libs/langchain/langchain/chains/rl_chain/metrics.py b/libs/langchain/langchain/chains/rl_chain/metrics.py index b7ec949c9eaa6..4d6306f776013 100644 --- a/libs/langchain/langchain/chains/rl_chain/metrics.py +++ b/libs/langchain/langchain/chains/rl_chain/metrics.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union if TYPE_CHECKING: import pandas as pd @@ -6,11 +6,11 @@ class MetricsTracker: def __init__(self, step: int): - self._history = [] - self._step = step - self._i = 0 - self._num = 0 - self._denom = 0 + self._history: List[Dict[str, Union[int, float]]] = [] + self._step: int = step + self._i: int = 0 + self._num: float = 0 + self._denom: float = 0 @property def score(self) -> float: diff --git a/libs/langchain/langchain/chains/rl_chain/model_repository.py b/libs/langchain/langchain/chains/rl_chain/model_repository.py index eea866d1cf3c4..87f162df0ab77 100644 --- a/libs/langchain/langchain/chains/rl_chain/model_repository.py +++ b/libs/langchain/langchain/chains/rl_chain/model_repository.py @@ -4,7 +4,7 @@ import os import shutil from pathlib import Path -from typing import TYPE_CHECKING, Sequence, Union +from typing import TYPE_CHECKING, List, Union if TYPE_CHECKING: import vowpal_wabbit_next as vw @@ -22,7 +22,7 @@ def __init__( self.folder = Path(folder) self.model_path = self.folder / "latest.vw" self.with_history = with_history - if reset and self.has_history: + if reset and self.has_history(): logger.warning( "There is non empty history which is recommended to be cleaned up" ) @@ -44,7 +44,7 @@ def save(self, workspace: "vw.Workspace") -> None: if self.with_history: # write history shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw") - def load(self, commandline: Sequence[str]) -> "vw.Workspace": + def load(self, commandline: List[str]) -> "vw.Workspace": import vowpal_wabbit_next as vw model_data = None diff --git 
a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index 6e1a1a5eff70b..fa7f18f8fb25d 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -1,12 +1,11 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type, Union import langchain.chains.rl_chain.base as base from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.prompts import BasePromptTemplate @@ -17,7 +16,36 @@ SENTINEL = object() -class PickBestFeatureEmbedder(base.Embedder): +class PickBestSelected(base.Selected): + index: Optional[int] + probability: Optional[float] + score: Optional[float] + + def __init__( + self, + index: Optional[int] = None, + probability: Optional[float] = None, + score: Optional[float] = None, + ): + self.index = index + self.probability = probability + self.score = score + + +class PickBestEvent(base.Event[PickBestSelected]): + def __init__( + self, + inputs: Dict[str, Any], + to_select_from: Dict[str, Any], + based_on: Dict[str, Any], + selected: Optional[PickBestSelected] = None, + ): + super().__init__(inputs=inputs, selected=selected) + self.to_select_from = to_select_from + self.based_on = based_on + + +class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): """ Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy @@ -25,7 +53,7 @@ class PickBestFeatureEmbedder(base.Embedder): model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer. """ # noqa E501 - def __init__(self, model: Optional[Any] = None, *args, **kwargs): + def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) if model is None: @@ -35,7 +63,7 @@ def __init__(self, model: Optional[Any] = None, *args, **kwargs): self.model = model - def format(self, event: PickBest.Event) -> str: + def format(self, event: PickBestEvent) -> str: """ Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW """ @@ -54,9 +82,14 @@ def format(self, event: PickBest.Event) -> str: to_select_from_var_name, to_select_from = next( iter(event.to_select_from.items()), (None, None) ) + action_embs = ( - base.embed(to_select_from, self.model, to_select_from_var_name) - if event.to_select_from + ( + base.embed(to_select_from, self.model, to_select_from_var_name) + if event.to_select_from + else None + ) + if to_select_from else None ) @@ -88,7 +121,7 @@ def format(self, event: PickBest.Event) -> str: return example_string[:-1] -class PickBest(base.RLChain): +class PickBest(base.RLChain[PickBestEvent]): """ `PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call. @@ -116,38 +149,10 @@ class PickBest(base.RLChain): feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized. 
""" # noqa E501 - class Selected(base.Selected): - index: Optional[int] - probability: Optional[float] - score: Optional[float] - - def __init__( - self, - index: Optional[int] = None, - probability: Optional[float] = None, - score: Optional[float] = None, - ): - self.index = index - self.probability = probability - self.score = score - - class Event(base.Event): - def __init__( - self, - inputs: Dict[str, Any], - to_select_from: Dict[str, Any], - based_on: Dict[str, Any], - selected: Optional[PickBest.Selected] = None, - ): - super().__init__(inputs=inputs, selected=selected) - self.to_select_from = to_select_from - self.based_on = based_on - def __init__( self, - feature_embedder: Optional[PickBestFeatureEmbedder] = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): vw_cmd = kwargs.get("vw_cmd", []) if not vw_cmd: @@ -163,14 +168,16 @@ def __init__( raise ValueError( "If vw_cmd is specified, it must include --cb_explore_adf" ) - kwargs["vw_cmd"] = vw_cmd + + feature_embedder = kwargs.get("feature_embedder", None) if not feature_embedder: feature_embedder = PickBestFeatureEmbedder() + kwargs["feature_embedder"] = feature_embedder - super().__init__(feature_embedder=feature_embedder, *args, **kwargs) + super().__init__(*args, **kwargs) - def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event: + def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: context, actions = base.get_based_on_and_to_select_from(inputs=inputs) if not actions: raise ValueError( @@ -193,12 +200,15 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event: to base the selected of ToSelectFrom on." ) - event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context) + event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context) return event def _call_after_predict_before_llm( - self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]] - ) -> Tuple[Dict[str, Any], PickBest.Event]: + self, + inputs: Dict[str, Any], + event: PickBestEvent, + prediction: List[Tuple[int, float]], + ) -> Tuple[Dict[str, Any], PickBestEvent]: import numpy as np prob_sum = sum(prob for _, prob in prediction) @@ -208,7 +218,7 @@ def _call_after_predict_before_llm( sampled_ap = prediction[sampled_index] sampled_action = sampled_ap[0] sampled_prob = sampled_ap[1] - selected = PickBest.Selected(index=sampled_action, probability=sampled_prob) + selected = PickBestSelected(index=sampled_action, probability=sampled_prob) event.selected = selected # only one key, value pair in event.to_select_from @@ -218,23 +228,29 @@ def _call_after_predict_before_llm( return next_chain_inputs, event def _call_after_llm_before_scoring( - self, llm_response: str, event: PickBest.Event - ) -> Tuple[Dict[str, Any], PickBest.Event]: + self, llm_response: str, event: PickBestEvent + ) -> Tuple[Dict[str, Any], PickBestEvent]: next_chain_inputs = event.inputs.copy() # only one key, value pair in event.to_select_from value = next(iter(event.to_select_from.values())) + v = ( + value[event.selected.index] + if event.selected + else event.to_select_from.values() + ) next_chain_inputs.update( { self.selected_based_on_input_key: str(event.based_on), - self.selected_input_key: value[event.selected.index], + self.selected_input_key: v, } ) return next_chain_inputs, event def _call_after_scoring_before_learning( - self, event: PickBest.Event, score: Optional[float] - ) -> Event: - event.selected.score = score + self, event: PickBestEvent, score: Optional[float] + ) 
-> PickBestEvent: + if event.selected: + event.selected.score = score return event def _call( @@ -248,33 +264,19 @@ def _call( def _chain_type(self) -> str: return "rl_chain_pick_best" - @classmethod - def from_chain( - cls, - llm_chain: Chain, - prompt: BasePromptTemplate, - selection_scorer=SENTINEL, - **kwargs: Any, - ): - if selection_scorer is SENTINEL: - selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) - return PickBest( - llm_chain=llm_chain, - prompt=prompt, - selection_scorer=selection_scorer, - **kwargs, - ) - @classmethod def from_llm( - cls, + cls: Type[PickBest], llm: BaseLanguageModel, prompt: BasePromptTemplate, - selection_scorer=SENTINEL, + selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL, **kwargs: Any, - ): + ) -> PickBest: llm_chain = LLMChain(llm=llm, prompt=prompt) - return PickBest.from_chain( + if selection_scorer is SENTINEL: + selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm) + + return PickBest( llm_chain=llm_chain, prompt=prompt, selection_scorer=selection_scorer, diff --git a/libs/langchain/langchain/chains/rl_chain/vw_logger.py b/libs/langchain/langchain/chains/rl_chain/vw_logger.py index 4fa471753957c..e8d2e1541f1c7 100644 --- a/libs/langchain/langchain/chains/rl_chain/vw_logger.py +++ b/libs/langchain/langchain/chains/rl_chain/vw_logger.py @@ -9,10 +9,10 @@ def __init__(self, path: Optional[Union[str, PathLike]]): if self.path: self.path.parent.mkdir(parents=True, exist_ok=True) - def log(self, vw_ex: str): + def log(self, vw_ex: str) -> None: if self.path: with open(self.path, "a") as f: f.write(f"{vw_ex}\n\n") - def logging_enabled(self): + def logging_enabled(self) -> bool: return bool(self.path) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index 3fad1667d91c2..7bca6b470d88a 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest from test_utils import MockEncoder @@ -10,7 +12,7 @@ @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def setup(): +def setup() -> tuple: _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm""" PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE) @@ -19,7 +21,7 @@ def setup(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_multiple_ToSelectFrom_throws(): +def test_multiple_ToSelectFrom_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) actions = ["0", "1", "2"] @@ -32,7 +34,7 @@ def test_multiple_ToSelectFrom_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_missing_basedOn_from_throws(): +def test_missing_basedOn_from_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) actions = ["0", "1", "2"] @@ -41,7 +43,7 @@ def test_missing_basedOn_from_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_ToSelectFrom_not_a_list_throws(): +def test_ToSelectFrom_not_a_list_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) actions = {"actions": ["0", "1", "2"]} @@ -53,7 +55,7 @@ def test_ToSelectFrom_not_a_list_throws(): @pytest.mark.requires("vowpal_wabbit_next", 
"sentence_transformers") -def test_update_with_delayed_score_with_auto_validator_throws(): +def test_update_with_delayed_score_with_auto_validator_throws() -> None: llm, PROMPT = setup() # this LLM returns a number so that the auto validator will return that auto_val_llm = FakeListChatModel(responses=["3"]) @@ -75,7 +77,7 @@ def test_update_with_delayed_score_with_auto_validator_throws(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_update_with_delayed_score_force(): +def test_update_with_delayed_score_force() -> None: llm, PROMPT = setup() # this LLM returns a number so that the auto validator will return that auto_val_llm = FakeListChatModel(responses=["3"]) @@ -99,7 +101,7 @@ def test_update_with_delayed_score_force(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_update_with_delayed_score(): +def test_update_with_delayed_score() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, selection_scorer=None @@ -117,11 +119,11 @@ def test_update_with_delayed_score(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_user_defined_scorer(): +def test_user_defined_scorer() -> None: llm, PROMPT = setup() class CustomSelectionScorer(rl_chain.SelectionScorer): - def score_response(self, inputs, llm_response: str) -> float: + def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float: score = 200 return score @@ -139,7 +141,7 @@ def score_response(self, inputs, llm_response: str) -> float: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_embeddings(): +def test_default_embeddings() -> None: llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( @@ -173,7 +175,7 @@ def test_default_embeddings(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_embeddings_off(): +def test_default_embeddings_off() -> None: llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( @@ -199,7 +201,7 @@ def test_default_embeddings_off(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_embeddings_mixed_w_explicit_user_embeddings(): +def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( @@ -234,7 +236,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_default_no_scorer_specified(): +def test_default_no_scorer_specified() -> None: _, PROMPT = setup() chain_llm = FakeListChatModel(responses=[100]) chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT) @@ -249,7 +251,7 @@ def test_default_no_scorer_specified(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_explicitly_no_scorer(): +def test_explicitly_no_scorer() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, selection_scorer=None @@ -265,7 +267,7 @@ def test_explicitly_no_scorer(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_auto_scorer_with_user_defined_llm(): +def test_auto_scorer_with_user_defined_llm() -> None: llm, PROMPT = setup() 
scorer_llm = FakeListChatModel(responses=[300]) chain = pick_best_chain.PickBest.from_llm( @@ -284,7 +286,7 @@ def test_auto_scorer_with_user_defined_llm(): @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_calling_chain_w_reserved_inputs_throws(): +def test_calling_chain_w_reserved_inputs_throws() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) with pytest.raises(ValueError): diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py index d8ea85c6ebcc2..c49bacac6085c 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py @@ -8,10 +8,10 @@ @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_missing_context_throws(): +def test_pickbest_textembedder_missing_context_throws() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_action = {"action": ["0", "1", "2"]} - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_action, based_on={} ) with pytest.raises(ValueError): @@ -19,9 +19,9 @@ def test_pickbest_textembedder_missing_context_throws(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_missing_actions_throws(): +def test_pickbest_textembedder_missing_actions_throws() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from={}, based_on={"context": "context"} ) with pytest.raises(ValueError): @@ -29,11 +29,11 @@ def test_pickbest_textembedder_missing_actions_throws(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_no_label_no_emb(): +def test_pickbest_textembedder_no_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on={"context": "context"} ) vw_ex_str = feature_embedder.format(event) @@ -41,12 +41,12 @@ def test_pickbest_textembedder_no_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_label_no_score_no_emb(): +def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on={"context": "context"}, @@ -57,14 +57,14 @@ def test_pickbest_textembedder_w_label_no_score_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_full_label_no_emb(): +def test_pickbest_textembedder_w_full_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = 
{"action1": ["0", "1", "2"]} expected = ( """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """ ) - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on={"context": "context"}, @@ -75,7 +75,7 @@ def test_pickbest_textembedder_w_full_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_full_label_w_emb(): +def test_pickbest_textembedder_w_full_label_w_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" str2 = "1" @@ -90,8 +90,8 @@ def test_pickbest_textembedder_w_full_label_w_emb(): named_actions = {"action1": rl_chain.Embed([str1, str2, str3])} context = {"context": rl_chain.Embed(ctx_str_1)} expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -99,7 +99,7 @@ def test_pickbest_textembedder_w_full_label_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): +def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" str2 = "1" @@ -114,8 +114,8 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])} context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)} expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -123,12 +123,12 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): +def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) @@ -136,13 +136,13 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") 
-def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): +def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -150,13 +150,13 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): +def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -164,7 +164,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -186,8 +186,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): } expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -195,7 +195,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> ( + None +): feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -219,8 +221,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee } expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} 
\n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -228,7 +230,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -253,8 +255,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)} expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -262,7 +264,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep(): +def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "0" @@ -290,8 +292,8 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_ } expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501 - selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0) - event = pick_best_chain.PickBest.Event( + selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context, selected=selected ) vw_ex_str = feature_embedder.format(event) @@ -299,7 +301,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_ @pytest.mark.requires("vowpal_wabbit_next") -def test_raw_features_underscored(): +def test_raw_features_underscored() -> None: feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) str1 = "this is a long string" str1_underscored = str1.replace(" ", "_") @@ -315,7 +317,7 @@ def test_raw_features_underscored(): expected_no_embed = ( f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """ ) - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) @@ -325,7 +327,7 @@ def test_raw_features_underscored(): named_actions = {"action": rl_chain.Embed([str1])} context = {"context": 
rl_chain.Embed(ctx_str)} expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """ - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) @@ -335,7 +337,7 @@ def test_raw_features_underscored(): named_actions = {"action": rl_chain.EmbedAndKeep([str1])} context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """ # noqa: E501 - event = pick_best_chain.PickBest.Event( + event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_actions, based_on=context ) vw_ex_str = feature_embedder.format(event) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py index 895fa8ebb6001..bd0cc584ef117 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_rl_chain_base_embedder.py @@ -1,3 +1,5 @@ +from typing import List, Union + import pytest from test_utils import MockEncoder @@ -7,13 +9,13 @@ @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_context_str_no_emb(): +def test_simple_context_str_no_emb() -> None: expected = [{"a_namespace": "test"}] assert base.embed("test", MockEncoder(), "a_namespace") == expected @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_context_str_w_emb(): +def test_simple_context_str_w_emb() -> None: str1 = "test" encoded_str1 = " ".join(char for char in str1) expected = [{"a_namespace": encoded_text + encoded_str1}] @@ -28,7 +30,7 @@ def test_simple_context_str_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_context_str_w_nested_emb(): +def test_simple_context_str_w_nested_emb() -> None: # nested embeddings, innermost wins str1 = "test" encoded_str1 = " ".join(char for char in str1) @@ -46,13 +48,13 @@ def test_simple_context_str_w_nested_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_no_emb(): +def test_context_w_namespace_no_emb() -> None: expected = [{"test_namespace": "test"}] assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_w_emb(): +def test_context_w_namespace_w_emb() -> None: str1 = "test" encoded_str1 = " ".join(char for char in str1) expected = [{"test_namespace": encoded_text + encoded_str1}] @@ -67,7 +69,7 @@ def test_context_w_namespace_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_w_emb2(): +def test_context_w_namespace_w_emb2() -> None: str1 = "test" encoded_str1 = " ".join(char for char in str1) expected = [{"test_namespace": encoded_text + encoded_str1}] @@ -82,7 +84,7 @@ def test_context_w_namespace_w_emb2(): @pytest.mark.requires("vowpal_wabbit_next") -def test_context_w_namespace_w_some_emb(): +def test_context_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" encoded_str2 = " ".join(char for char in str2) @@ -111,16 +113,17 @@ def test_context_w_namespace_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_action_strlist_no_emb(): +def test_simple_action_strlist_no_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": 
str3}] - assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected + to_embed: List[Union[str, base._Embed]] = [str1, str2, str3] + assert base.embed(to_embed, MockEncoder(), "a_namespace") == expected @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_action_strlist_w_emb(): +def test_simple_action_strlist_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -148,7 +151,7 @@ def test_simple_action_strlist_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_simple_action_strlist_w_some_emb(): +def test_simple_action_strlist_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -181,7 +184,7 @@ def test_simple_action_strlist_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_no_emb(): +def test_action_w_namespace_no_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -204,7 +207,7 @@ def test_action_w_namespace_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_emb(): +def test_action_w_namespace_w_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -246,7 +249,7 @@ def test_action_w_namespace_w_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_emb2(): +def test_action_w_namespace_w_emb2() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -292,7 +295,7 @@ def test_action_w_namespace_w_emb2(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_some_emb(): +def test_action_w_namespace_w_some_emb() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -333,7 +336,7 @@ def test_action_w_namespace_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict(): +def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None: str1 = "test1" str2 = "test2" str3 = "test3" @@ -384,7 +387,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict(): @pytest.mark.requires("vowpal_wabbit_next") -def test_one_namespace_w_list_of_features_no_emb(): +def test_one_namespace_w_list_of_features_no_emb() -> None: str1 = "test1" str2 = "test2" expected = [{"test_namespace": [str1, str2]}] @@ -392,7 +395,7 @@ def test_one_namespace_w_list_of_features_no_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_one_namespace_w_list_of_features_w_some_emb(): +def test_one_namespace_w_list_of_features_w_some_emb() -> None: str1 = "test1" str2 = "test2" encoded_str2 = " ".join(char for char in str2) @@ -404,24 +407,24 @@ def test_one_namespace_w_list_of_features_w_some_emb(): @pytest.mark.requires("vowpal_wabbit_next") -def test_nested_list_features_throws(): +def test_nested_list_features_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder()) @pytest.mark.requires("vowpal_wabbit_next") -def test_dict_in_list_throws(): +def test_dict_in_list_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder()) @pytest.mark.requires("vowpal_wabbit_next") -def test_nested_dict_throws(): +def test_nested_dict_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder()) @pytest.mark.requires("vowpal_wabbit_next") -def test_list_of_tuples_throws(): +def test_list_of_tuples_throws() -> None: with pytest.raises(ValueError): base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder()) diff --git 
a/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py index 6d54d20d9219f..625c37ee00029 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_utils.py @@ -1,3 +1,3 @@ class MockEncoder: - def encode(self, to_encode): + def encode(self, to_encode: str) -> str: return "[encoded]" + to_encode
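
A minimal usage sketch of the API as refactored in this diff (hedged, not part of the patch): it mirrors the unit tests above, uses a fake chat model purely for illustration, and the FakeListChatModel import path, placeholder prompt, and action strings are assumptions rather than verified details of this branch.

import langchain.chains.rl_chain as rl_chain
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
from langchain.chat_models.fake import FakeListChatModel  # import path assumed; same fake LLM used in the tests above
from langchain.prompts import PromptTemplate

# Prompt content is ignored by the fake LLM; it only anchors the internal LLMChain.
PROMPT = PromptTemplate(input_variables=[], template="dummy prompt")
llm = FakeListChatModel(responses=["hey"])
scorer_llm = FakeListChatModel(responses=["3"])  # AutoSelectionScorer expects a numeric reply

# from_llm now builds the LLMChain and the default AutoSelectionScorer itself;
# the removed from_chain constructor is no longer needed. The default
# PickBestFeatureEmbedder pulls in sentence_transformers; pass feature_embedder=... to override.
chain = pick_best_chain.PickBest.from_llm(
    llm=llm,
    prompt=PROMPT,
    selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
)

# BasedOn marks context features, ToSelectFrom marks the candidate actions.
response = chain.run(
    User=rl_chain.BasedOn("Tom"),
    action=rl_chain.ToSelectFrom(["an action", "another action"]),
)
event = response["selection_metadata"]  # a PickBestEvent (renamed from PickBest.Event in this diff)
print(event.selected.index, event.selected.probability, event.selected.score)

# Scores can also be applied after the fact; force_score=True is required here
# because a selection scorer is configured on the chain.
chain.update_with_delayed_score(score=1.0, event=event, force_score=True)
chain.save_progress()  # delegates to active_policy.save(), persisting the VW workspace

Note that the `policy` constructor argument now takes a class (Type[Policy]) rather than an instance, and the chain holds it in `active_policy`, which starts out as the internal _NoOpPolicy placeholder until __init__ replaces it.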