Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
b committed Nov 5, 2024
1 parent b1b9673 commit bc81b96
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions autogen/agentchat/custom_nested_chat_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union, get_type_hints


class CustomNestedChatCondition():
def __init__(self,
func: Callable[..., bool],
state_params: Optional[dict[str, Any]]=None,
name: Optional[str] = None,
state_ttl_management: Literal["STATELESS", "STATE_KEPT_TILL_TRUE", "STATE_KEPT_TILL_FALSE"] = "STATELESS"):
class CustomNestedChatCondition:
def __init__(
self,
func: Callable[..., bool],
state_params: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
state_ttl_management: Literal["STATELESS", "STATE_KEPT_TILL_TRUE", "STATE_KEPT_TILL_FALSE"] = "STATELESS",
):
"""Class to hold a user-defined function signature and its parameters to act as conditions for nested conversations.
Args:
Expand All @@ -20,32 +22,32 @@ def __init__(self,
The following 2 enum values are for scenarios where the trigger function needs to work with runtime state variables and external data that changes by itself. Should the developer need the runtime state data to work with that external data, this should provide that support
(2) When "STATE_KEPT_TILL_TRUE", the state will be kept as-is until trigger function returns true, whereby it will be reset to None.
(3) When "STATE_KEPT_TILL_FALSE", the state will be kept as-is until trigger function returns false, whereby it will be reset to None.
"""

self.func = func # Store the function itself
self.state_params = state_params # Store the parameter names as a dict
self.func = func # Store the function itself
self.state_params = state_params # Store the parameter names as a dict
if name:
self._name = name
else:
self._name = func.__name__

self.state_ttl_management=state_ttl_management
self.state_ttl_management = state_ttl_management

# ennforce func is callable that returns bool
if not callable(func):
raise ValueError(
"Function must be callable type that returns bool."
)
raise ValueError("Function must be callable type that returns bool.")
sig = inspect.signature(func)
type_hints = get_type_hints(self.func)
if (len(sig.parameters.items()) > 0):
if len(sig.parameters.items()) > 0:
self.func_has_params = True
# iterate over items in params and sig.parameters, ensure a 1:1 match between them based on key of params and name of sig.params and data type
for param_name, param in sig.parameters.items():
# Check if the parameter exists in the provided params
if param_name not in self.state_params:
raise ValueError(f"Parameter '{param_name}' is required by the function '{self.func.__name__}' but was not provided in params.")
raise ValueError(
f"Parameter '{param_name}' is required by the function '{self.func.__name__}' but was not provided in params."
)
# Ensure that type of func param is same as that of self.params counterpart
expected_type = type_hints.get(param_name, Any)
if not isinstance(self.state_params[param_name], expected_type) and expected_type is not Any:
Expand All @@ -59,9 +61,7 @@ def __init__(self,
# Check if func returns a boolean
result = func(*[None] * len(sig.parameters)) # Call func with default None args for demo
if not isinstance(result, bool):
raise TypeError(
f"The function '{func.__name__}' must return a boolean value."
)
raise TypeError(f"The function '{func.__name__}' must return a boolean value.")
except TypeError:
pass # Ignore if the function can't be called without arguments (further checking may be necessary)

Expand All @@ -71,9 +71,7 @@ def call_function(self):
"""
# Call the function using **params to match by parameter name
if self.func_has_params and self.state_params == None:
raise TypeError(
f"{self._name} is missing parameters"
)
raise TypeError(f"{self._name} is missing parameters")
return False
trigger_result = self.func(**self.state_params)

Expand Down

0 comments on commit bc81b96

Please sign in to comment.