diff --git a/altk/core/llm/base.py b/altk/core/llm/base.py
index 6db35cf..08ff3c3 100644
--- a/altk/core/llm/base.py
+++ b/altk/core/llm/base.py
@@ -175,6 +175,12 @@ def provider_class(cls) -> Type[Any]:
         Underlying SDK client class, e.g. openai.OpenAI or litellm.LiteLLM.
         """
 
+    @abstractmethod
+    def get_model_id(self) -> str|None:
+        """
+        Returns the ID of the model, e.g. "gpt-5.1", "meta-llama/llama-4-maverick-17b-128e-instruct-fp8", or "claude-4-sonnet".
+        """
+
     @abstractmethod
     def _register_methods(self) -> None:
         """
diff --git a/altk/core/llm/providers/auto_from_env/auto_from_env.py b/altk/core/llm/providers/auto_from_env/auto_from_env.py
index 8d622c5..ddcaa53 100644
--- a/altk/core/llm/providers/auto_from_env/auto_from_env.py
+++ b/altk/core/llm/providers/auto_from_env/auto_from_env.py
@@ -46,6 +46,10 @@ def __init__(self) -> None:
     def provider_class(cls) -> Type[Any]:
         raise NotImplementedError
 
+    def get_model_id(self) -> str|None:
+        if self._chosen_provider:
+            return self._chosen_provider.get_model_id()
+
     def _register_methods(self) -> None:
         if self._chosen_provider:
             self._chosen_provider._register_methods()
diff --git a/altk/core/llm/providers/ibm_watsonx_ai/ibm_watsonx_ai.py b/altk/core/llm/providers/ibm_watsonx_ai/ibm_watsonx_ai.py
index 6a9bd5a..6237c7a 100644
--- a/altk/core/llm/providers/ibm_watsonx_ai/ibm_watsonx_ai.py
+++ b/altk/core/llm/providers/ibm_watsonx_ai/ibm_watsonx_ai.py
@@ -127,6 +127,9 @@ def provider_class(cls) -> Type[Any]:
         """
         return ModelInference  # type: ignore
 
+    def get_model_id(self) -> str:
+        return self.model_name  # type: ignore
+
     def _register_methods(self) -> None:
         """
         Register how to call watsonx methods:
@@ -471,6 +474,9 @@ def provider_class(cls) -> Type[Any]:
         """
         return ModelInference  # type: ignore
 
+    def get_model_id(self) -> str:
+        return self.model_name  # type: ignore
+
     def _register_methods(self) -> None:
         """
         Register how to call watsonx methods for validation:
diff --git a/altk/core/llm/providers/litellm/litellm.py b/altk/core/llm/providers/litellm/litellm.py
index 7ec917f..5dbeaa8 100644
--- a/altk/core/llm/providers/litellm/litellm.py
+++ b/altk/core/llm/providers/litellm/litellm.py
@@ -35,6 +35,9 @@ def __init__(
     @classmethod
     def provider_class(cls) -> type:
         return litellm  # type: ignore
+
+    def get_model_id(self) -> str:
+        return self.model_path
 
     def _register_methods(self) -> None:
         """Register LiteLLM methods - only chat and chat_async are supported"""
@@ -302,6 +305,9 @@ def provider_class(cls) -> Type[Any]:
         Must be callable with no arguments (per LLMClient __init__ logic).
""" return litellm # type: ignore + + def get_model_id(self) -> str: + return self.model_path def _register_methods(self) -> None: """ diff --git a/altk/core/llm/providers/openai/openai.py b/altk/core/llm/providers/openai/openai.py index 765e5c1..dc86e18 100644 --- a/altk/core/llm/providers/openai/openai.py +++ b/altk/core/llm/providers/openai/openai.py @@ -77,6 +77,9 @@ def transform_min_tokens(value: Any, mode: Any) -> dict[str, Any]: ) self._parameter_mapper.set_custom_transform("min_tokens", transform_min_tokens) + def get_model_id(self) -> str|None: + if self._other_kwargs: + return self._other_kwargs.get("model") class BaseValidatingOpenAIClient(ValidatingLLMClient): """Base class for validating OpenAI and Azure OpenAI clients with shared parameter mapping""" @@ -146,6 +149,10 @@ def transform_min_tokens(value: Any, mode: Any) -> dict[str, Any]: ) self._parameter_mapper.set_custom_transform("min_tokens", transform_min_tokens) + def get_model_id(self) -> str|None: + if self._other_kwargs: + return self._other_kwargs.get("model") + @register_llm("openai.sync") class SyncOpenAIClient(BaseOpenAIClient, BaseLLMClient): diff --git a/altk/pre_llm/core/types.py b/altk/pre_llm/core/types.py index 33c280f..f06c9f5 100644 --- a/altk/pre_llm/core/types.py +++ b/altk/pre_llm/core/types.py @@ -125,4 +125,4 @@ def get_topics( n_results: int = 10, query_kwargs: Dict[str, Any] | None = None, distance_threshold: float | None = None, - ) -> List[RetrievedTopic]: ... + ) -> List[RetrievedTopic]: ... \ No newline at end of file diff --git a/altk/pre_tool/core/types.py b/altk/pre_tool/core/types.py index 55f404b..7abcbe6 100644 --- a/altk/pre_tool/core/types.py +++ b/altk/pre_tool/core/types.py @@ -90,44 +90,4 @@ class SPARCReflectionRunOutput(PreToolReflectionRunOutput): output: SPARCReflectionRunOutputSchema = Field( default_factory=lambda: SPARCReflectionRunOutputSchema() - ) - - -class ToolGuardBuildInputMetaData(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - policy_text: str = Field(description="Text of the policy document file") - short1: bool = Field(default=True, description="Run build short or long version. 
") - validating_llm_client: LLMClient = Field( - description="ValidatingLLMClient for build time" - ) - - -class ToolGuardBuildInput(ComponentInput): - metadata: ToolGuardBuildInputMetaData = Field( - default_factory=lambda: ToolGuardBuildInputMetaData() - ) - - -class ToolGuardRunInputMetaData(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - tool_name: str = Field(description="Tool name") - tool_parms: dict = Field(default={}, description="Tool parameters") - llm_client: LLMClient = Field(description="LLMClient for build time") - - -class ToolGuardRunInput(ComponentInput): - metadata: ToolGuardRunInputMetaData = Field( - default_factory=lambda: ToolGuardRunInputMetaData() - ) - - -class ToolGuardRunOutputMetaData(BaseModel): - error_message: Union[str, bool] = Field( - description="Error string or False if no error occurred" - ) - - -class ToolGuardRunOutput(ComponentOutput): - output: ToolGuardRunOutputMetaData = Field( - default_factory=lambda: ToolGuardRunOutputMetaData() - ) + ) \ No newline at end of file diff --git a/altk/pre_tool/examples/calculator_example/__init__.py b/altk/pre_tool/examples/calculator_example/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/examples/calculator_example/example_tools.py b/altk/pre_tool/examples/calculator_example/example_tools.py deleted file mode 100644 index a028d50..0000000 --- a/altk/pre_tool/examples/calculator_example/example_tools.py +++ /dev/null @@ -1,81 +0,0 @@ -from langchain_core.tools import tool - - -@tool -def add_tool(a: float, b: float) -> float: - """ - Add two numbers. - - Parameters - ---------- - a : float - The first number to add. - b : float - The second number to add. - - Returns - ------- - float - The sum of a and b. - """ - return a + b - - -@tool -def subtract_tool(c: float, d: float) -> float: - """ - Subtract one number from another. - - Parameters - ---------- - c : float - The number to subtract from. - d : float - The number to subtract. - - Returns - ------- - float - The result of a minus b. - """ - return c - d - - -@tool -def multiply_tool(e: float, f: float) -> float: - """ - Multiply two numbers. - - Parameters - ---------- - e : float - The first number. - f : float - The second number. - - Returns - ------- - float - The product of a and b. - """ - return e * f - - -@tool -def divide_tool(g: float, h: float) -> float: - """ - Divide one number by another. - - Parameters - ---------- - g : float - The dividend. - h : float - The divisor (must not be zero). - - Returns - ------- - float - The result of a divided by b. - """ - return g / h diff --git a/altk/pre_tool/examples/calculator_example/policy_document.md b/altk/pre_tool/examples/calculator_example/policy_document.md deleted file mode 100644 index 556136b..0000000 --- a/altk/pre_tool/examples/calculator_example/policy_document.md +++ /dev/null @@ -1,30 +0,0 @@ -# Calculator Usage Policy - -This document outlines the rules and constraints that govern the behavior and usage of the calculator application. - -## General Principles - -- The calculator should perform accurate and reliable computations. -- All operations must conform to mathematical standards and avoid undefined behavior. - -## Supported Operations - -The calculator supports the following operations: - -- Addition (`+`) -- Subtraction (`-`) -- Multiplication (`*`) -- Division (`/`) - -## Operation Constraints - -- **Division by Zero is Not Allowed** - The calculator **must not** allow division by zero. 
- If a user attempts to divide by zero, the operation must be rejected and an appropriate error message should be shown - (e.g., `"Error: Division by zero is not allowed."`). - -- **Summing Numbers Whose Product is 365 is Not Allowed** - The calculator **must not** allow addition of two or more numbers if their multiplication result equals `365`. - For example, adding `5 + 73` should be disallowed, because `5 * 73 = 365` . - In such cases, the operation must be rejected with an error like: - `"Error: Addition of numbers whose product equals 365 is not allowed."` diff --git a/altk/pre_tool/examples/calculator_example/run_example.py b/altk/pre_tool/examples/calculator_example/run_example.py deleted file mode 100644 index 70291a1..0000000 --- a/altk/pre_tool/examples/calculator_example/run_example.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -from pathlib import Path - -import markdown -from altk.core.llm import get_llm - -from examples.calculator_example.example_tools import ( - add_tool, - subtract_tool, - multiply_tool, - divide_tool, -) -from examples.tool_guard_example import ToolGuardExample - -subdir_name = "work_dir_wx" -work_dir = Path.cwd() / subdir_name -policy_doc_path = os.path.join(str(Path.cwd()), "policy_document.md") -work_dir.mkdir(exist_ok=True) - -OPENAILiteLLMClientOutputVal = get_llm("litellm.output_val") -validating_llm_client = OPENAILiteLLMClientOutputVal( - model_path="gpt-4o-2024-08-06", - custom_llm_provider="azure", -) - -OPENAILiteLLMClient = get_llm("litellm") -llm_client = OPENAILiteLLMClient( - model_path="gpt-4o-2024-08-06", - custom_llm_provider="azure", -) -tool_funcs = [add_tool, subtract_tool, multiply_tool, divide_tool] -policy_text = open(policy_doc_path, "r", encoding="utf-8").read() -policy_text = markdown.markdown(policy_text) - -tool_guard_example = ToolGuardExample( - tools=tool_funcs, - workdir=work_dir, - policy_text=policy_text, - validating_llm_client=validating_llm_client, -) -run_output = tool_guard_example.run_example( - "Can you please calculate how much is 3/4?", - "divide_tool", - {"g": 3, "h": 4}, - llm_client, -) -print(run_output) -passed = not run_output.output.error_message -if passed: - print("success!") -else: - print("failure!") -run_output = tool_guard_example.run_example( - "Can you please calculate how much is 5/0?", - "divide_tool", - {"g": 5, "h": 0}, - llm_client, -) -print(run_output) -passed = not run_output.output.error_message -if not passed: - print("success!") -else: - print("failure!") diff --git a/altk/pre_tool/examples/tool_guard_example.py b/altk/pre_tool/examples/tool_guard_example.py deleted file mode 100644 index f0d0d15..0000000 --- a/altk/pre_tool/examples/tool_guard_example.py +++ /dev/null @@ -1,51 +0,0 @@ -import dotenv - -from langchain_core.messages import HumanMessage - -from altk.pre_tool.core.types import ( - ToolGuardBuildInputMetaData, - ToolGuardBuildInput, - ToolGuardRunInputMetaData, - ToolGuardRunInput, -) -from altk.pre_tool.pre_tool_guard import PreToolGuardComponent - -# Load environment variables -dotenv.load_dotenv() - - -class ToolGuardExample: - """ - Runs examples with a ToolGuard component and validates tool invocation against policy. 
- """ - - def __init__(self, tools, workdir, policy_text, validating_llm_client, short=True): - self.tools = tools - self.middleware = PreToolGuardComponent(tools=tools, workdir=workdir) - - build_input = ToolGuardBuildInput( - metadata=ToolGuardBuildInputMetaData( - policy_text=policy_text, - short1=short, - validating_llm_client=validating_llm_client, - ) - ) - self.middleware._build(build_input) - - def run_example( - self, user_message: str, tool_name: str, tool_params: dict, llm_client - ): - """ - Runs a single example through ToolGuard and checks if the result matches the expectation. - """ - conversation_context = [HumanMessage(content=user_message)] - - run_input = ToolGuardRunInput( - messages=conversation_context, - metadata=ToolGuardRunInputMetaData( - tool_name=tool_name, tool_parms=tool_params, llm_client=llm_client - ), - ) - - run_output = self.middleware._run(run_input) - return run_output diff --git a/altk/pre_tool/toolguard/README.md b/altk/pre_tool/toolguard/README.md index 8b17d99..3b1f46b 100644 --- a/altk/pre_tool/toolguard/README.md +++ b/altk/pre_tool/toolguard/README.md @@ -1,52 +1,44 @@ # ToolGuards for Enforcing Agentic Policy Adherence -An agent lifecycle solution for enforcing business policy adherence in agentic workflows. Enabling this component has demonstrated up to a **20‑point improvement** in end‑to‑end agent accuracy when invoking tools. +An agent lifecycle solution for enforcing business policy adherence in agentic workflows. Enabling this component has demonstrated up to a **20‑point improvement** in end‑to‑end agent accuracy when invoking tools. This work is described in [EMNLP 2025 Towards Enforcing Company Policy Adherence in Agentic Workflows](https://arxiv.org/pdf/2507.16459), and is publiched in [this GitHub library](https://github.com/AgentToolkit/toolguard). ## Table of Contents - [Overview](#overview) -- [When to Use This Component](#when-it-is-recommended-to-use-this-component) -- [LLM Configuration Requirements](#llm-configuration-requirements) -- [Quick Start](#quick-start) -- [Parameters](#parameters) - - [Constructor Parameters](#constructor-parameters) - - [Build Phase Input Format](#build-phase-input-format) - - [Run Phase Input Format](#run-phase-input-format) - - [Run Phase Output Format](#run-phase-output-format) - +- [ToolGuardSpecComponent](#ToolGuardSpecComponent) + - [Configuarion](#component-configuarion) + - [Inputs and Outputs](#input-and-output) + - [Usage example](#usage-example) +- [ToolGuardCodeComponent](#ToolGuardCodeComponent) + - [Configuarion](#component-configuarion-1) + - [Inputs and Outputs](#input-and-output-1) + - [Usage example](#usage-example-1) ## Overview Business policies (or guidelines) are normally detailed in company documents, and have traditionally been hard-coded into automatic assistant platforms. Contemporary agentic approaches take the "best-effort" strategy, where the policies are appended to the agent's system prompt, an inherently non-deterministic approach, that does not scale effectively. Here we propose a deterministic, predictable and interpretable two-phase solution for agentic policy adherence at the tool-level: guards are executed prior to function invocation and raise alerts in case a tool-related policy deem violated. - -### Key Components - -The solution enforces policy adherence through a two-phase process: - -(1) **Buildtime**: an offline two-step pipeline that automatically maps policy fragments to the relevant tools and generates policy validation code - ToolGuards. 
-
-(2) **Runtime**: ToolGuards are deployed within the agent's ReAct flow, and are executed after "reason" and just before "act" (agent's tool invocation). If a planned action violates a policy, the agent is prompted to self-reflect and revise its plan before proceeding. Ultimately, the deployed ToolGuards will prevent the agent from taking an action violating a policy.
-
-
-
-
-## When it is Recommended to Use This Component
 This component enforces **pre‑tool activation policy constraints**, ensuring that agent decisions comply with business rules **before** modifying system state. This prevents policy violations such as unauthorized tool calls or unsafe parameter values.
 
-## LLM Configuration Requirements
-The **build phase** uses **two LLMs**:
+## ToolGuardSpecComponent
+This component takes a set of tools and a policy document and generates multiple ToolGuard specifications, known as `ToolGuardSpec`s. Each specification is attached to a tool, and it declares a precondition that must hold before the tool is invoked. The specification has a `name`, a `description`, a list of `references` to the original policy document, a set of declarative `compliance_examples` describing test cases in which the ToolGuard should allow the tool invocation, and `violation_examples`, in which the ToolGuard should raise an exception.
 
-### 1. Reasoning LLM (Build Step 1)
-Used to interpret, restructure, and classify policy text.
+This component supports only a `build` phase. The generated specifications are returned as output, and are also saved to a specified file system directory.
+The specifications are intended to be used as input to our next component, the `ToolGuardCodeComponent` described below.
 
-This model can be any LLM registered through:
-```python
-from altk.core.llm import get_llm  # def get_llm(name: str) -> Type["LLMClient"]
+
+The two components are deliberately not chained together. Since the generation involves a non-deterministic language model, the results need to be reviewed by a human. Hence, the output specification files should be reviewed and optionally edited, for example to remove an incorrect compliance example.
 
-OPENAILiteLLMClientOutputVal = get_llm("litellm.output_val")
+### Component Configuration
+This component expects an LLM client configuration:
+```python
+from altk.core.llm import get_llm
+LLMClient = get_llm("litellm.output_val")
+llm_client = LLMClient(...)
+toolguard_component = ToolGuardSpecComponent(
+    ToolGuardSpecComponentConfig(llm_client=llm_client)
+)
 ```
-#### Azure example for gpt-4o:
+Here is a concrete example with `litellm` and `azure`:
 Environment variables:
 ```bash
 export AZURE_OPENAI_API_KEY=""
@@ -55,114 +47,92 @@ export AZURE_API_VERSION="2024-08-01-preview"
 ```
 code:
 ```python
-from altk.core.llm import get_llm  # def get_llm(name: str) -> Type["LLMClient"]
+from altk.core.llm import get_llm
 
-OPENAILiteLLMClientOutputVal = get_llm("litellm.output_val")
-validating_llm_client = OPENAILiteLLMClientOutputVal(
+LLMClient = get_llm("litellm.output_val")
+llm_client = LLMClient(
     model_name="gpt-4o-2024-08-06",
     custom_llm_provider="azure",
 )
-
 ```
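+
+A minimal build-phase sketch may then look as follows (the `build` entry point and the package-root import are assumptions here, not the confirmed API; the `ToolGuardSpecBuildInput` fields are described in the Input and Output section below, and the linked usage example shows the tested flow):
+```python
+from altk.pre_tool.toolguard import ToolGuardSpecBuildInput  # assumed export
+
+build_input = ToolGuardSpecBuildInput(
+    policy_text=policy_text,        # text of the policy document
+    tools=[add_tool, divide_tool],  # plain Python functions or LangGraph tools
+    out_dir="specs_out",            # where the ToolGuardSpec files are saved
+)
+specs = toolguard_component.build(build_input)  # hypothetical entry point; returns List[ToolGuardSpec]
+```
+The returned `ToolGuardSpec` files can then be reviewed and edited on disk before being handed to the `ToolGuardCodeComponent`.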
-### 2. Code Generation LLM (Build Step 2)
-Used only in the code generation phase to produce Python enforcement logic.
-Backed by Mellea, which requires parameters aligning to:
-```python
-mellea.MelleaSession.start_session(
-    backend_name=...,
-    model_id=...,
-    backend_kwargs=...  # any additional arguments
-)
-```
+### Input and Output
+The component build input is a `ToolGuardSpecBuildInput` object containing the following fields:
+  * `policy_text: str`: Text of the policy document.
+  * `tools: List[Callable] | List[BaseTool] | str`: List of available tools, either as Python functions, methods, LangGraph tools, or a path to an OpenAPI specification file.
+  * `out_dir: str`: A directory in the local file system where the specification objects will be saved.
 
-These map directly to environment variables:
+The component build output is a list of `ToolGuardSpec`, as described above.
 
-| Environment Variable | Mellea Parameter | Description |
-| ------------------------------ | ---------------- | ------------------------------------------------------------------ |
-| `TOOLGUARD_GENPY_BACKEND_NAME` | `backend_name` | Which backend to use (e.g., `openai`, `anthropic`, `vertex`, etc.) |
-| `TOOLGUARD_GENPY_MODEL_ID` | `model_id` | Model name / deployment id |
-| `TOOLGUARD_GENPY_ARGS` | `backend_kwargs` | JSON dict of any additional connection/LLM parameters |
+### Usage example
+See the [simple calculator test](../../../tests/pre_tool/toolguard/test_toolguard_specs.py).
 
-Example (Claude-4 Sonnet through OpenAI-compatible endpoint):
-```bash
-export TOOLGUARD_GENPY_BACKEND_NAME="openai"
-export TOOLGUARD_GENPY_MODEL_ID="GCP/claude-4-sonnet"
-export TOOLGUARD_GENPY_ARGS='{"base_url":"https://your-litellm-endpoint","api_key":""}'
-```
-## Quick Start
-See runnable example:
-```
-pre-tool-guard-toolkit/examples/calculator_example
-```
 
-```python
-import asyncio
-from altk.pre_tool.toolguard.core import (
-    ToolGuardBuildInput, ToolGuardBuildInputMetaData,
-    ToolGuardRunInput, ToolGuardRunInputMetaData,
-)
-from altk.pre_tool.toolguard.pre_tool_guard import PreToolGuardComponent
-
-class ToolGuardExample:
-    def __init__(self, tools, workdir, policy_text, validating_llm_client, short=True):
-        self.middleware = PreToolGuardComponent(tools=tools, workdir=workdir, app_name="calculator")
-        build_input = ToolGuardBuildInput(metadata=ToolGuardBuildInputMetaData(
-            policy_text=policy_text,
-            short1=short,
-            validating_llm_client=validating_llm_client,
-        ))
-        asyncio.run(self.middleware._build(build_input))
-
-    def run_example(self, tool_name, tool_params):
-        run_input = ToolGuardRunInput(
-            metadata=ToolGuardRunInputMetaData(tool_name=tool_name, tool_parms=tool_params),
-        )
-        return self.middleware._run(run_input)
-```
 
+## ToolGuardCodeComponent
+
+This component enforces policy adherence through a two-phase process:
+(1) **Buildtime**: Given a set of `ToolGuardSpec`s, generates policy validation code - `ToolGuard`s.
+Similar to the ToolGuard specifications, generated `ToolGuard`s are a good start, but they may contain errors. Hence, they should also be reviewed by a human.
 
-## Parameters
+(2) **Runtime**: ToolGuards are deployed within the agent's flow, and are triggered before the agent's tool invocation. They can be deployed in the agent loop, or in an MCP Gateway.
+The ToolGuards check whether a planned action complies with the policy. If it does not, the agent is prompted to self-reflect and revise its plan before proceeding.
 
-### Constructor Parameters
-```python
-PreToolGuardComponent(tools, workdir)
-```
-| Parameter | Type | Description |
-|----------|------------------|-------------|
-| `tools` | `list[Callable]` | List of functions or LangChain tools to safeguard.
-| `workdir` | `str` or `Path` | Writable working directory for storing build artifacts.
+### Component Configuration
 
-### Build Phase Input Format
-```python
-ToolGuardBuildInput(
-    metadata=ToolGuardBuildInputMetaData(
-        policy_text="",
-        short1=True,
-        validating_llm_client=
-    )
-)
-```
+This component expects an LLM client configuration.
+Here is an example using a Watsonx LLM client:
+```python
+from altk.core.llm.providers.ibm_watsonx_ai.ibm_watsonx_ai import WatsonxLLMClientOutputVal
+llm = WatsonxLLMClientOutputVal(
+    model_name="meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
+    api_key=os.getenv("WATSONX_API_KEY"),
+    project_id=os.getenv("WATSONX_PROJECT_ID"),
+    url=os.getenv("WATSONX_URL"),
+)
+config = ToolGuardCodeComponentConfig(llm_client=llm)
+toolguard_code_component = ToolGuardCodeComponent(config)
+```
 
-### Run Phase Input Format
-```python
-ToolGuardRunInput(
-    metadata=ToolGuardRunInputMetaData(
-        tool_name="divide_tool",
-        tool_parms={"g": 3, "h": 4},
-    ),
-    messages=[{"role": "user", "content": "Calculate 3/4"}]
-)
-```
 
-### Run Phase Output Format
-```python
-ToolGuardRunOutput(output=ToolGuardRunOutputMetaData(error_message=False))
-```
-`error_message` is either `False` (valid) or a descriptive violation message.
 
-## License
-Apache 2.0 - see LICENSE file for details.
-
+**Important note:** The Code component works best with *closed models* such as [GPT-4o](https://openai.com/index/hello-gpt-4o/), [Gemini](https://deepmind.google/technologies/gemini/), and [Claude](https://www.anthropic.com/claude).
+
+### Input and Output
+The component has two phases:
+#### Build phase
+An agent owner should use this API to generate ToolGuards - Python functions that enforce the given business policy.
+The input of the build phase is a `ToolGuardCodeBuildInput` object, containing:
+  * `tools: List[Callable] | List[BaseTool] | str`: List of available tools, either as Python functions, methods, LangGraph tools, or a path to an OpenAPI specification file.
+  * `toolguard_specs: List[ToolGuardSpec]`: List of specifications, optionally generated by the `ToolGuardSpecComponent` and then reviewed.
+  * `out_dir: str`: A directory in the local file system where the ToolGuard objects will be saved.
+
+The output of the build phase is a `ToolGuardsCodeGenerationResult` object with:
+  * `out_dir: str`: Path in the file system where the results were saved. It is the same as the input's `out_dir`.
+  * `domain: RuntimeDomain`: A complex object describing the generated APIs, for example, references to Python file names and class names.
+  * `tools: Dict[str, ToolGuardCodeResult]`: A dictionary of `ToolGuardCodeResult`s, keyed by tool name.
+    * Each `ToolGuardCodeResult` details the guard's Python file name and the guard function name. It also references the generated unit test files.
+
+#### Runtime phase
+A running agent should use the runtime API to check if a tool call complies with the given policy.
+The input of the runtime phase is a `ToolGuardCodeRunInput` object (a runtime sketch follows the list below):
+  * `generated_guard_dir: str`: Path in the local file system where the generated guard Python code (the code generated at build time, described above) is located.
+  * `tool_name: str`: The name of the tool that the agent is about to call.
+  * `tool_args: Dict[str, Any]`: A dictionary of the tool-call arguments, keyed by argument name.
+  * `tool_invoker: IToolInvoker`: A proxy object that enables the guard to call other read-only tools. This is needed when the policy enforcement logic involves getting data from another tool. For example, before booking a flight, the guard may need to check the flight status by calling the "get_flight_status" API.
+    The `IToolInvoker` interface contains a single method:
+    ```
+    def invoke(self, toolname: str, arguments: Dict[str, Any], return_type: Type[T]) -> T
+    ```
+
+    The ToolGuard library currently ships with three predefined ToolInvokers:
+    * `toolguard.runtime.ToolFunctionsInvoker(funcs: List[Callable])` where the tools are defined as plain global Python functions.
+    * `toolguard.runtime.ToolMethodsInvoker(obj: object)` where the tools are defined as methods on a given Python object.
+    * `toolguard.runtime.LangchainToolInvoker(tools: List[BaseTool])` where the tools are a list of LangChain tools.
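+
+A minimal runtime-phase sketch, wiring these inputs together (the `run` entry point is an assumption here, not the confirmed API; `toolguard_code_component` is the instance configured above, and `add_tool`/`divide_tool` are the example calculator tools):
+```python
+from toolguard.runtime import ToolFunctionsInvoker
+
+run_input = ToolGuardCodeRunInput(  # assumed to be importable from the package root
+    generated_guard_dir="guards_out",  # where the build phase saved the generated guards
+    tool_name="divide_tool",           # the tool the agent is about to call
+    tool_args={"g": 5, "h": 0},        # a division by zero, which the policy forbids
+    tool_invoker=ToolFunctionsInvoker([add_tool, divide_tool]),
+)
+run_output = toolguard_code_component.run(run_input)  # hypothetical entry point
+if run_output.violation:
+    # Surface the violation to the user, or feed it back into the agent's reasoning step.
+    print(run_output.violation.violation_level, run_output.violation.user_message)
+```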
+
+The output of the runtime phase is a `ToolGuardCodeRunOutput` object with an optional `violation` field.
+  * `violation: PolicyViolation | None`: Populated only if a violation was identified. If the tool call complies with the policy, `violation` is `None`.
+    * `violation_level: "info" | "warn" | "error"`: Severity level of a safety violation.
+    * `user_message: str | None`: A meaningful error message to the user (this message can also be passed to the agent reasoning phase to find an alternative next action).
+
+### Usage example
+See the [simple calculator test](../../../tests/pre_tool/toolguard/test_toolguard_code.py).
diff --git a/altk/pre_tool/toolguard/__init__.py b/altk/pre_tool/toolguard/__init__.py
index e69de29..d29a083 100644
--- a/altk/pre_tool/toolguard/__init__.py
+++ b/altk/pre_tool/toolguard/__init__.py
@@ -0,0 +1,2 @@
+from .toolguard_code_component import *
+from .toolguard_spec_component import *
\ No newline at end of file
diff --git a/altk/pre_tool/toolguard/core/__init__.py b/altk/pre_tool/toolguard/core/__init__.py
deleted file mode 100644
index 97fd319..0000000
--- a/altk/pre_tool/toolguard/core/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from .types import (
-    ToolGuardBuildInput,
-    ToolGuardRunInput,
-    ToolGuardRunOutput,
-    ToolGuardBuildInputMetaData,
-    ToolGuardRunInputMetaData,
-    ToolGuardRunOutputMetaData,
-)
-
-
-__all__ = [
-    "ToolGuardBuildInput",
-    "ToolGuardRunInput",
-    "ToolGuardRunOutput",
-    "ToolGuardBuildInputMetaData",
-    "ToolGuardRunInputMetaData",
-    "ToolGuardRunOutputMetaData",
-]
diff --git a/altk/pre_tool/toolguard/core/types.py b/altk/pre_tool/toolguard/core/types.py
deleted file mode 100644
index 1397350..0000000
--- a/altk/pre_tool/toolguard/core/types.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Union, List
-
-from altk.pre_tool.toolguard.toolguard.data_types import ToolPolicy
-from altk.pre_tool.toolguard.toolguard.runtime import ToolGuardsCodeGenerationResult
-from altk.core.toolkit import ComponentInput, ComponentOutput
-from altk.core.llm import BaseLLMClient
-from pydantic import BaseModel, Field, ConfigDict
-
-
-class ToolGuardBuildInputMetaData(BaseModel):
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-    policy_text: str = Field(description="Text of the policy document file")
-    short1: bool = Field(default=True, description="Run build short or long version.
") - validating_llm_client: BaseLLMClient = Field( - description="ValidatingLLMClient for build time" - ) - - -class ToolGuardBuildInput(ComponentInput): - metadata: ToolGuardBuildInputMetaData = Field( - default_factory=lambda: ToolGuardBuildInputMetaData() - ) - - -class ToolGuardRunInputMetaData(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - tool_name: str = Field(description="Tool name") - tool_parms: dict = Field(default={}, description="Tool parameters") - # llm_client: LLMClient = Field(description="LLMClient for build time") - - -class ToolGuardRunInput(ComponentInput): - metadata: ToolGuardRunInputMetaData = Field( - default_factory=lambda: ToolGuardRunInputMetaData() - ) - - -class ToolGuardBuildOutputMetaData(BaseModel): - tool_policies: List[ToolPolicy] = ( - Field( - description="List of policies specs for each tool extracted from the policy document" - ), - ) - generated_code_object: ToolGuardsCodeGenerationResult = Field( - description="root_dir of the generated code object, runtime domain and code for each tool guard" - ) - - -class ToolGuardBuildOutput(ComponentOutput): - output: ToolGuardBuildOutputMetaData = Field( - default_factory=lambda: ToolGuardBuildOutputMetaData() - ) - - -class ToolGuardRunOutputMetaData(BaseModel): - error_message: Union[str, bool] = Field( - description="Error string or False if no error occurred" - ) - - -class ToolGuardRunOutput(ComponentOutput): - output: ToolGuardRunOutputMetaData = Field( - default_factory=lambda: ToolGuardRunOutputMetaData() - ) diff --git a/altk/pre_tool/toolguard/examples/.gitignore b/altk/pre_tool/toolguard/examples/.gitignore deleted file mode 100644 index d9862e1..0000000 --- a/altk/pre_tool/toolguard/examples/.gitignore +++ /dev/null @@ -1 +0,0 @@ -work_dir_wx diff --git a/altk/pre_tool/toolguard/examples/__init__.py b/altk/pre_tool/toolguard/examples/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/toolguard/examples/calculator_example/__init__.py b/altk/pre_tool/toolguard/examples/calculator_example/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/toolguard/examples/calculator_example/example_tools.py b/altk/pre_tool/toolguard/examples/calculator_example/example_tools.py deleted file mode 100644 index 46f73f1..0000000 --- a/altk/pre_tool/toolguard/examples/calculator_example/example_tools.py +++ /dev/null @@ -1,92 +0,0 @@ -def add_tool(a: float, b: float) -> float: - """ - Add two numbers. - - Parameters - ---------- - a : float - The first number to add. - b : float - The second number to add. - - Returns - ------- - float - The sum of a and b. - """ - return a + b - - -def subtract_tool(c: float, d: float) -> float: - """ - Subtract one number from another. - - Parameters - ---------- - c : float - The number to subtract from. - d : float - The number to subtract. - - Returns - ------- - float - The result of a minus b. - """ - return c - d - - -def multiply_tool(e: float, f: float) -> float: - """ - Multiply two numbers. - - Parameters - ---------- - e : float - The first number. - f : float - The second number. - - Returns - ------- - float - The product of a and b. - """ - return e * f - - -def divide_tool(g: float, h: float) -> float: - """ - Divide one number by another. - - Parameters - ---------- - g : float - The dividend. - h : float - The divisor (must not be zero). - - Returns - ------- - float - The result of a divided by b. 
- """ - return g / h - - -def map_kdi_number(i: float) -> float: - """ - return the mapping of the numer i to it's kdi value - - Parameters - ---------- - i : float - The number to map. - - - Returns - ------- - float - The value of the dki of the given number. - """ - return 3.14 * i diff --git a/altk/pre_tool/toolguard/examples/calculator_example/run_example.py b/altk/pre_tool/toolguard/examples/calculator_example/run_example.py deleted file mode 100644 index 79ada58..0000000 --- a/altk/pre_tool/toolguard/examples/calculator_example/run_example.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -from pathlib import Path - -import markdown - -from altk.pre_tool.toolguard.examples.calculator_example.example_tools import ( - add_tool, - subtract_tool, - multiply_tool, - divide_tool, - map_kdi_number, -) -from altk.pre_tool.toolguard.examples.tool_guard_example import ToolGuardExample -from altk.core.llm import get_llm - - -subdir_name = "work_dir_wx" -script_path = os.path.abspath(__file__) -script_directory = os.path.dirname(script_path) -work_dir = Path(os.path.join(script_directory, subdir_name)) -policy_doc_path = os.path.join(script_directory, "policy_document.md") -work_dir.mkdir(exist_ok=True) - -OPENAILiteLLMClientOutputVal = get_llm("litellm.output_val") -validating_llm_client = OPENAILiteLLMClientOutputVal( - model_name="gpt-4o-2024-08-06", - custom_llm_provider="azure", -) - - -tool_funcs = [add_tool, subtract_tool, multiply_tool, divide_tool, map_kdi_number] -policy_text = open(policy_doc_path, "r", encoding="utf-8").read() -policy_text = markdown.markdown(policy_text) - -tool_guard_example = ToolGuardExample( - tools=tool_funcs, - workdir=work_dir, - policy_text=policy_text, - validating_llm_client=validating_llm_client, - app_name="calculator", -) -run_output = tool_guard_example.run_example( - "divide_tool", - {"g": 3, "h": 4}, -) -print(run_output) -passed = not run_output.output.error_message -if passed: - print("success!") -else: - print("failure!") - -run_output = tool_guard_example.run_example( - "divide_tool", - {"g": 5, "h": 0}, -) -print(run_output) -passed = not run_output.output.error_message -if not passed: - print("success!") -else: - print("failure!") - -run_output = tool_guard_example.run_example( - "add_tool", - {"a": 5, "b": 44}, -) -print(run_output) -passed = not run_output.output.error_message -if passed: - print("success!") -else: - print("failure!") - -run_output = tool_guard_example.run_example( - "add_tool", - {"a": 5, "b": 73}, -) -print(run_output) -passed = not run_output.output.error_message -if not passed: - print("success!") -else: - print("failure!") - -run_output = tool_guard_example.run_example( - "multiply_tool", - {"e": 3, "f": 44}, -) -print(run_output) -passed = not run_output.output.error_message -if passed: - print("success!") -else: - print("failure!") - -run_output = tool_guard_example.run_example( - "multiply_tool", - {"e": 2, "f": 73}, -) -print(run_output) -passed = not run_output.output.error_message -if not passed: - print("success!") -else: - print("failure!") diff --git a/altk/pre_tool/toolguard/examples/full_test_calc.py b/altk/pre_tool/toolguard/examples/full_test_calc.py deleted file mode 100644 index 29d1f5f..0000000 --- a/altk/pre_tool/toolguard/examples/full_test_calc.py +++ /dev/null @@ -1,110 +0,0 @@ -import asyncio -import os -from typing import List, Dict - -import dotenv -import markdown - -from altk.pre_tool.toolguard.examples.calculator_example.example_tools import ( - divide_tool, -) -from 
altk.pre_tool.toolguard.toolguard.llm.tg_llmevalkit import TG_LLMEval -from altk.pre_tool.toolguard.toolguard.tool_policy_extractor.text_tool_policy_generator import ( - ToolInfo, - extract_policies, -) -from altk.pre_tool.toolguard.toolguard.core import ( - generate_guards_from_tool_policies, -) - -from altk.core.llm import get_llm - - -class FullAgent: - def __init__( - self, - app_name, - tools, - workdir, - policy_doc_path, - llm_model="gpt-4o-2024-08-06", - tools2run: List[str] | None = None, - short1=False, - ): - self.model = llm_model - self.tools = tools - self.workdir = workdir - self.policy_doc = open(policy_doc_path, "r", encoding="utf-8").read() - self.policy_doc = markdown.markdown(self.policy_doc) - self.tools2run = tools2run - self.short1 = short1 - self.app_name = app_name - self.step1_out_dir = os.path.join(self.workdir, "step1") - self.step2_out_dir = os.path.join(self.workdir, "step2") - # self.tool_registry = {tool.name: tool for tool in tools} - self.tool_registry = {tool.__name__: tool for tool in tools} - - async def build_time(self): - OPENAILiteLLMClientOutputVal = get_llm("litellm.output_val") - validating_llm_client = OPENAILiteLLMClientOutputVal( - model_name="watsonx/gpt-4o-2024-08-06", - custom_llm_provider="azure", - ) - llm = TG_LLMEval(validating_llm_client) - tools_info = [ToolInfo.from_function(tool) for tool in self.tools] - - tool_policies = await extract_policies( - self.policy_doc, tools_info, self.step1_out_dir, llm, short=True - ) - self.gen_result = await generate_guards_from_tool_policies( - self.tools, - tool_policies, - to_step2_path=self.step2_out_dir, - app_name=self.app_name, - ) - - def guard_tool(self, tool_name: str, tool_params: Dict) -> str: - print("validate_tool_node") - import sys - - code_root_dir = self.gen_result.root_dir - sys.path.insert(0, code_root_dir) - from rt_toolguard import load_toolguards - - toolguards = load_toolguards(code_root_dir) - - try: - # app_guards.check_tool_call(tool_name, tool_parms, state["messages"]) - toolguards.check_toolcall( - tool_name, tool_params, list(self.tool_registry.values()) - ) - print("ok to invoke tool") - except Exception as e: - error_message = ( - "it is against the policy to invoke tool: " - + tool_name - + " Error: " - + str(e) - ) - print(error_message) - - -if __name__ == "__main__": - dotenv.load_dotenv() - work_dir = "examples/calculator_example/output" - policy_doc_path = "examples/calculator_example/policy_document.md" - policy_doc_path = os.path.abspath(policy_doc_path) - work_dir = os.path.abspath(work_dir) - - tools = [divide_tool] # [add_tool, subtract_tool, multiply_tool, divide_tool] - fa = FullAgent( - "calculator", - tools, - work_dir, - policy_doc_path, - llm_model="gpt-4o-2024-08-06", - short1=True, - ) - asyncio.run(fa.build_time()) - fa.guard_tool("divide_tool", {"g": 5, "h": 0}) - fa.guard_tool("divide_tool", {"g": 5, "h": 4}) diff --git a/altk/pre_tool/toolguard/examples/tool_guard_example.py b/altk/pre_tool/toolguard/examples/tool_guard_example.py deleted file mode 100644 index 085516d..0000000 --- a/altk/pre_tool/toolguard/examples/tool_guard_example.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -import dotenv - - -from altk.pre_tool.toolguard.core import ( - ToolGuardBuildInput, - ToolGuardBuildInputMetaData, - ToolGuardRunInput, - ToolGuardRunInputMetaData, -) -from altk.pre_tool.toolguard.pre_tool_guard import PreToolGuardComponent - -# Load environment variables -dotenv.load_dotenv() - - -class ToolGuardExample: - """ - Runs examples with a ToolGuard 
component and validates tool invocation against policy. - """ - - def __init__( - self, tools, workdir, policy_text, validating_llm_client, app_name, short=True - ): - self.tools = tools - self.middleware = PreToolGuardComponent( - tools=tools, workdir=workdir, app_name=app_name - ) - - build_input = ToolGuardBuildInput( - metadata=ToolGuardBuildInputMetaData( - policy_text=policy_text, - short1=short, - validating_llm_client=validating_llm_client, - ) - ) - self.output = asyncio.run(self.middleware._build(build_input)) - - def run_example(self, tool_name: str, tool_params: dict): - """ - Runs a single example through ToolGuard and checks if the result matches the expectation. - """ - - run_input = ToolGuardRunInput( - metadata=ToolGuardRunInputMetaData( - tool_name=tool_name, - tool_parms=tool_params, - ), - ) - - run_output = self.middleware._run(run_input) - return run_output diff --git a/altk/pre_tool/toolguard/llm_client.py b/altk/pre_tool/toolguard/llm_client.py new file mode 100644 index 0000000..5a54b69 --- /dev/null +++ b/altk/pre_tool/toolguard/llm_client.py @@ -0,0 +1,25 @@ + +from typing import Union, cast +from altk.core.llm.types import GenerationArgs +from ...core.llm import ValidatingLLMClient, LLMClient +from toolguard.llm.tg_litellm import LanguageModelBase + + +class TG_LLMEval(LanguageModelBase): + def __init__(self, llm_client: Union[LLMClient, ValidatingLLMClient]): + super().__init__(llm_client.get_model_id()) # type: ignore + self.llm_client = llm_client + + async def generate(self, messages: list[dict]) -> str: + if isinstance(self.llm_client, ValidatingLLMClient): + llm_client = cast(ValidatingLLMClient, self.llm_client) + return await llm_client.generate_async( + prompt=messages, + schema=str, + generation_args = GenerationArgs(max_tokens=10000) + ) + + return await self.llm_client.generate_async( + prompt=messages, + generation_args = GenerationArgs(max_tokens=10000) + ) # type: ignore diff --git a/altk/pre_tool/toolguard/pre_tool_guard/__init__.py b/altk/pre_tool/toolguard/pre_tool_guard/__init__.py deleted file mode 100644 index ff307ab..0000000 --- a/altk/pre_tool/toolguard/pre_tool_guard/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .pre_tool_guard import PreToolGuardComponent - -__all__ = ["PreToolGuardComponent"] diff --git a/altk/pre_tool/toolguard/pre_tool_guard/pre_tool_guard.py b/altk/pre_tool/toolguard/pre_tool_guard/pre_tool_guard.py deleted file mode 100644 index b98ced8..0000000 --- a/altk/pre_tool/toolguard/pre_tool_guard/pre_tool_guard.py +++ /dev/null @@ -1,91 +0,0 @@ -import json -import logging -import os -from typing import Set - - -from altk.pre_tool.toolguard.toolguard.core import ( - generate_guards_from_tool_policies, -) -from altk.pre_tool.toolguard.toolguard.llm.tg_llmevalkit import TG_LLMEval -from altk.pre_tool.toolguard.toolguard.runtime import ToolFunctionsInvoker -from altk.pre_tool.toolguard.toolguard.tool_policy_extractor.text_tool_policy_generator import ( - ToolInfo, - extract_policies, -) - -from altk.core.toolkit import AgentPhase, ComponentBase - - -from altk.pre_tool.toolguard.core.types import ( - ToolGuardBuildInput, - ToolGuardRunInput, - ToolGuardRunOutput, - ToolGuardRunOutputMetaData, - ToolGuardBuildOutput, - ToolGuardBuildOutputMetaData, -) - -logger = logging.getLogger(__name__) - - -class PreToolGuardComponent(ComponentBase): - def __init__(self, tools, workdir, app_name): - super().__init__() - self._tools = tools - self._tool_registry = {tool.__name__: tool for tool in self._tools} - self._workdir = workdir - 
self._step1_dir = os.path.join(self._workdir, "Step_1") - self._step2_dir = os.path.join(self._workdir, "Step_2") - self._app_name = app_name - self._tool_policies = None - self._gen_result = None - - @classmethod - def supported_phases(cls) -> Set[AgentPhase]: - """Return the supported agent phases.""" - return {AgentPhase.BUILDTIME, AgentPhase.RUNTIME} - - async def _build(self, data: ToolGuardBuildInput) -> ToolGuardBuildOutput: - llm = TG_LLMEval(data.metadata.validating_llm_client) - tools_info = [ToolInfo.from_function(tool) for tool in self._tools] - - self._tool_policies = await extract_policies( - data.metadata.policy_text, tools_info, self._step1_dir, llm, short=True - ) - self._gen_result = await generate_guards_from_tool_policies( - self._tools, - self._tool_policies, - to_step2_path=self._step2_dir, - app_name=self._app_name, - ) - output = ToolGuardBuildOutputMetaData( - tool_policies=self._tool_policies, generated_code_object=self._gen_result - ) - return ToolGuardBuildOutput(output=output) - - def _run(self, data: ToolGuardRunInput) -> ToolGuardRunOutput: - import sys - - code_root_dir = self._gen_result.root_dir - sys.path.insert(0, code_root_dir) - tool_name = data.metadata.tool_name - tool_params = data.metadata.tool_parms - from rt_toolguard import load_toolguards - - app_guards = load_toolguards(code_root_dir) - - try: - app_guards.check_toolcall( - tool_name, - tool_params, - ToolFunctionsInvoker(list(self._tool_registry.values())), - ) - error_message = False - except Exception as e: - error_message = ( - f"It is against the policy to invoke tool: {tool_name}({json.dumps(tool_params)}) Error: " - + str(e) - ) - output = ToolGuardRunOutputMetaData(error_message=error_message) - return ToolGuardRunOutput(output=output) diff --git a/altk/pre_tool/toolguard/toolguard/__init__.py b/altk/pre_tool/toolguard/toolguard/__init__.py deleted file mode 100644 index 59b7732..0000000 --- a/altk/pre_tool/toolguard/toolguard/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# from .core import - -from altk.pre_tool.toolguard.toolguard.core import ( - build_toolguards, - extract_policies, - generate_guards_from_tool_policies, -) diff --git a/altk/pre_tool/toolguard/toolguard/common/__init__.py b/altk/pre_tool/toolguard/toolguard/common/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/toolguard/toolguard/common/array.py b/altk/pre_tool/toolguard/toolguard/common/array.py deleted file mode 100644 index 412df9c..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/array.py +++ /dev/null @@ -1,62 +0,0 @@ -import functools -from typing import Callable, List, TypeVar, Any - - -def flatten(arr_arr): - return [b for bs in arr_arr for b in bs] - - -def break_array_into_chunks(arr: List, chunk_size: int) -> List[List[Any]]: - res = [] - for i, v in enumerate(arr): - if i % chunk_size == 0: - cur = [] - res.append(cur) - cur.append(v) - return res - - -def sum(array): - return functools.reduce(lambda a, b: a + b, array) if len(array) > 0 else 0 - - -T = TypeVar("T") - - -def find(array: List[T], pred: Callable[[T], bool]): - for item in array: - if pred(item): - return item - - -# remove duplicates and preserve ordering -T = TypeVar("T") - - -def remove_duplicates(array: List[T]) -> List[T]: - res = [] - visited = set() - for item in array: - if item not in visited: - res.append(item) - visited.add(item) - return res - - -def not_none(array: List[T]) -> List[T]: - return [item for item in array if item is not None] - - -def split_array(arr: List[T], delimiter: T) 
-> List[List[T]]: - result = [] - temp = [] - for item in arr: - if item == delimiter: - if temp: # Avoid adding empty lists - result.append(temp) - temp = [] - else: - temp.append(item) - if temp: # Add the last group if not empty - result.append(temp) - return result diff --git a/altk/pre_tool/toolguard/toolguard/common/dict.py b/altk/pre_tool/toolguard/toolguard/common/dict.py deleted file mode 100644 index 590e767..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/dict.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Dict, Any, Callable -import hashlib -import json -from collections import deque - - -def dict_deep_merge(trg_dct: Dict, merge_dct: Dict): - for k, v in merge_dct.items(): - trg_v = trg_dct.get(k) - if k in trg_dct and isinstance(trg_v, dict) and isinstance(v, dict): - dict_deep_merge(trg_v, v) - elif isinstance(trg_v, list) and isinstance(v, list): - if all([type(item) is dict for item in v]) and all( - [type(item) is dict for item in trg_v] - ): # both are lists of objects - for a, b in zip(trg_v, v): - dict_deep_merge(a, b) - else: - for item in v: - if item not in trg_dct[k]: - trg_dct[k].append(item) - else: - trg_dct[k] = v - - -def get_keys_recursive(d: Dict): - keys = [] - for key, value in d.items(): - keys.append(key) - if isinstance(value, dict): - keys.extend(get_keys_recursive(value)) - elif isinstance(value, list): - for item in value: - if isinstance(item, dict): - keys.extend(get_keys_recursive(item)) - return keys - - -# adapted from https://www.doc.ic.ac.uk/~nuric/posts/coding/how-to-hash-a-dictionary-in-python/ -def dict_hash(dictionary: Dict[str, Any]) -> int: - encoded = json.dumps(dictionary, sort_keys=True).encode() - return int(hashlib.sha256(encoded).hexdigest(), 16) - - -def visit_all(d: Any, cb: Callable[[Dict, Any], bool]): - if type(d) is dict: - for k in list( - d - ): # list() to avoid RuntimeError: dictionary changed size during iteration - changed = cb(d, k) - if not changed: - visit_all(d[k], cb) - if type(d) is list: - for item in d: - visit_all(item, cb) - - -def resolve_ref(ref, root_schema, resolved_refs): - """Resolve a $ref to its actual definition in the schema.""" - if ref in resolved_refs: - return resolved_refs[ref] - - resolved_schema = find_ref(root_schema, ref) - resolved_refs[ref] = resolved_schema - return substitute_refs(resolved_schema, root_schema, resolved_refs) - - -def substitute_refs(schema, root_schema=None, resolved_refs=None): - """Substitute all $refs in the JSON schema with their definitions.""" - if root_schema is None: - root_schema = schema - if resolved_refs is None: - resolved_refs = {} - - if isinstance(schema, dict): - if "$ref" in schema: - ref = schema["$ref"] - if ref in resolved_refs: - return resolved_refs[ref] - resolved = resolve_ref(ref, root_schema, resolved_refs) - resolved_refs[ref] = resolved - return resolved - else: - return { - k: substitute_refs(v, root_schema, resolved_refs) - for k, v in schema.items() - } - elif isinstance(schema, list): - return [substitute_refs(item, root_schema, resolved_refs) for item in schema] - else: - return schema - - -def find_ref(doc: dict, ref: str): - q = deque(ref.split("/")) - cur = doc - while q: - field = q.popleft() - if field == "#": - cur = doc - continue - if field in cur: - cur = cur[field] - else: - return None - if "$ref" in cur: - return find_ref(doc, cur.get("$ref")) # recursive. infinte loops? 
- return cur diff --git a/altk/pre_tool/toolguard/toolguard/common/http.py b/altk/pre_tool/toolguard/toolguard/common/http.py deleted file mode 100644 index 2cb37be..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/http.py +++ /dev/null @@ -1,55 +0,0 @@ -from enum import Enum -from functools import cache -from typing import List - - -MEDIA_TYPE_APP_JSON = "application/json" -MEDIA_TYPE_MULTIPART_FORM = "multipart/form-data" -MEDIA_TYPE_APP_FORM = "application/x-www-form-urlencoded" - - -class StrEnum(str, Enum): - """An abstract base class for string-based enums.""" - - pass - - -class HttpMethod(StrEnum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - - def __eq__(self, value: object) -> bool: - return str(value).upper() == self.value - - def __ne__(self, value: object) -> bool: - return not self.__eq__(value) - - def __hash__(self): - return self.value.__hash__() - - @classmethod - @cache - def list(cls) -> List["HttpMethod"]: - return list(map(lambda c: c.value, cls)) - - -def is_valid_http_method(val: str): - return val and val.upper() in HttpMethod.list() - - -PARAM_API_KEY = "api_key" - -AUTH_HEADER = "Authorization" - -SECURITY_COMPONENT_TYPE_API_KEY = "apiKey" -SECURITY_COMPONENT_SCHEME_BEARER = "bearer" -SECURITY_COMPONENT_SCHEME_BASIC = "basic" - -SECURITY_COMPONENT_BEARER = { - "type": "http", - "scheme": SECURITY_COMPONENT_SCHEME_BEARER, - "bearerFormat": "JWT", -} diff --git a/altk/pre_tool/toolguard/toolguard/common/jschema.py b/altk/pre_tool/toolguard/toolguard/common/jschema.py deleted file mode 100644 index 867a900..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/jschema.py +++ /dev/null @@ -1,42 +0,0 @@ -from enum import Enum -from typing import Any, Dict, List, Optional, Union - - -from altk.pre_tool.toolguard.toolguard.common.ref import DocumentWithRef, Reference - - -class StrEnum(str, Enum): - """An abstract base class for string-based enums.""" - - pass - - -class JSONSchemaTypes(StrEnum): - string = "string" - number = "number" - integer = "integer" - boolean = "boolean" - array = "array" - object = "object" - null = "null" - - -class JSchema(DocumentWithRef): - type: Optional[JSONSchemaTypes] = None - properties: Optional[Dict[str, Union[Reference, "JSchema"]]] = None - items: Optional[Union[Reference, "JSchema"]] = None - additionalProperties: Optional[Union["JSchema", bool]] = None - format: Optional[str] = None - enum: Optional[list] = None - default: Optional[Any] = None - description: Optional[str] = None - example: Optional[Any] = None - required: Optional[List[str]] = None - allOf: Optional[List[Union[Reference, "JSchema"]]] = None - anyOf: Optional[List[Union[Reference, "JSchema"]]] = None - nullable: Optional[bool] = ( - None # in OPenAPISpec https://swagger.io/docs/specification/v3_0/data-models/data-types/#null - ) - - def __str__(self) -> str: - return self.model_dump_json(exclude_none=True, indent=2) diff --git a/altk/pre_tool/toolguard/toolguard/common/llm_py.py b/altk/pre_tool/toolguard/toolguard/common/llm_py.py deleted file mode 100644 index b6ab4fa..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/llm_py.py +++ /dev/null @@ -1,12 +0,0 @@ -import re - -PYTHON_PATTERN = r"^```python\s*\n([\s\S]*)\n```" - - -def get_code_content(llm_code) -> str: - code = llm_code.replace("\\n", "\n") - match = re.match(PYTHON_PATTERN, code) - if match: - return match.group(1) - - return code diff --git a/altk/pre_tool/toolguard/toolguard/common/open_api.py b/altk/pre_tool/toolguard/toolguard/common/open_api.py 
deleted file mode 100644 index 7e5cad1..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/open_api.py +++ /dev/null @@ -1,207 +0,0 @@ -from enum import Enum -from pydantic import BaseModel, Field, HttpUrl -from typing import List, Dict, Optional, Any, TypeVar, Union -import json -import yaml - -from altk.pre_tool.toolguard.toolguard.common.dict import find_ref -from altk.pre_tool.toolguard.toolguard.common.http import MEDIA_TYPE_APP_JSON -from altk.pre_tool.toolguard.toolguard.common.jschema import JSchema -from altk.pre_tool.toolguard.toolguard.common.ref import Reference - - -class Contact(BaseModel): - name: Optional[str] = None - url: Optional[HttpUrl] = None - email: Optional[str] = None - - -class License(BaseModel): - name: str - identifier: Optional[str] = None - url: Optional[HttpUrl] = None - - -class Info(BaseModel): - title: str - summary: Optional[str] = None - description: Optional[str] = None - termsOfService: Optional[HttpUrl] = None - contact: Optional[Contact] = None - license: Optional[License] = None - version: str - - -class ServerVariable(BaseModel): - enum: Optional[List[str]] = None - default: str - description: Optional[str] = None - - -class Server(BaseModel): - url: str - description: Optional[str] = None - variables: Optional[Dict[str, ServerVariable]] = None - - -class ExternalDocumentation(BaseModel): - description: Optional[str] = None - url: HttpUrl - - -class Tag(BaseModel): - name: str - description: Optional[str] = None - externalDocs: Optional[ExternalDocumentation] = None - - -class MediaType(BaseModel): - schema_: Optional[Union[Reference, JSchema]] = Field(None, alias="schema") - example: Optional[Any] = None - examples: Optional[Dict[str, Any]] = None - - -class RequestBody(BaseModel): - description: Optional[str] = None - required: Optional[bool] = None - content: Optional[Dict[str, MediaType]] = None - - @property - def content_json(self): - if self.content: - return self.content.get(MEDIA_TYPE_APP_JSON) - - -class Response(BaseModel): - description: Optional[str] = None - content: Optional[Dict[str, MediaType]] = None - - @property - def content_json(self): - if self.content: - return self.content.get(MEDIA_TYPE_APP_JSON) - - -class StrEnum(str, Enum): - """An abstract base class for string-based enums.""" - - pass - - -class ParameterIn(StrEnum): - query = "query" - header = "header" - cookie = "cookie" - path = "path" - - -class Parameter(BaseModel): - name: str - description: Optional[str] = None - in_: ParameterIn = Field(ParameterIn.query, alias="in") - required: Optional[bool] = None - schema_: Optional[Union[Reference, JSchema]] = Field(None, alias="schema") - - -class Operation(BaseModel): - summary: Optional[str] = None - description: Optional[str] = None - operationId: Optional[str] = None - tags: Optional[List[str]] = None - parameters: Optional[List[Union[Reference, Parameter]]] = None - requestBody: Optional[Union[Reference, RequestBody]] = None - responses: Optional[Dict[str, Union[Reference, Response]]] = None - security: Optional[Dict[str, List[str]]] = None - - -class PathItem(BaseModel): - summary: Optional[str] = None - description: Optional[str] = None - servers: Optional[List[Server]] = None - parameters: Optional[List[Union[Reference, Parameter]]] = None - get: Optional[Operation] = None - put: Optional[Operation] = None - post: Optional[Operation] = None - delete: Optional[Operation] = None - options: Optional[Operation] = None - head: Optional[Operation] = None - patch: Optional[Operation] = None - trace: 
Optional[Operation] = None - - @property - def operations(self): - d = { - "get": self.get, - "put": self.put, - "post": self.post, - "delete": self.delete, - "options": self.options, - "head": self.head, - "patch": self.patch, - "trace": self.trace, - } - return {k: v for k, v in d.items() if v is not None} - - -class Components(BaseModel): - schemas: Optional[Dict[str, JSchema]] = None - responses: Optional[Dict[str, Response]] = None - parameters: Optional[Dict[str, Parameter]] = None - examples: Optional[Dict[str, Any]] = None - requestBodies: Optional[Dict[str, RequestBody]] = None - headers: Optional[Dict[str, Any]] = None - securitySchemes: Optional[Dict[str, Any]] = None - links: Optional[Dict[str, Any]] = None - callbacks: Optional[Dict[str, Any]] = None - pathItems: Optional[Dict[str, PathItem]] = None - - -BaseModelT = TypeVar("BaseModelT", bound=BaseModel) - - -class OpenAPI(BaseModel): - openapi: str = Field(..., pattern=r"^3\.\d\.\d+(-.+)?$") - info: Info - jsonSchemaDialect: Optional[HttpUrl] = ( - "https://spec.openapis.org/oas/3.1/dialect/WORK-IN-PROGRESS" - ) - servers: Optional[List[Server]] = [Server(url="/")] - paths: Dict[str, Union[Reference, PathItem]] = {} - webhooks: Optional[Dict[str, PathItem]] = None - components: Optional[Components] = None - security: Optional[List[Dict[str, List[str]]]] = None # Refined to List of Dicts - tags: Optional[List[Tag]] = None - externalDocs: Optional[ExternalDocumentation] = None - - def get_operation_by_operationId(self, operationId: str) -> Operation | None: - for path_item in self.paths.values(): - for op in path_item.operations.values(): - if op.operationId == operationId: - return op - - def resolve_ref( - self, obj: Reference | BaseModelT | None, object_type: type[BaseModelT] - ) -> BaseModelT | None: - if isinstance(obj, Reference): - tmp = find_ref(self.model_dump(), obj.ref) - return object_type.model_validate(tmp) - return obj - - def save(self, file_name: str): - if file_name.endswith(".json"): - with open(file_name, "w", encoding="utf-8") as f: - f.write( - self.model_dump_json(indent=2, by_alias=True, exclude_none=True) - ) - return - # TODO yaml - raise NotImplementedError() - - -def read_openapi(file_path: str) -> OpenAPI: - with open(file_path, "r") as file: - if file_path.endswith("json"): - d = json.load(file) - else: - d = yaml.safe_load(file) - return OpenAPI.model_validate(d, strict=False) diff --git a/altk/pre_tool/toolguard/toolguard/common/py.py b/altk/pre_tool/toolguard/toolguard/common/py.py deleted file mode 100644 index 880ffb8..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/py.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import inspect -from typing import Callable -import sys -from pathlib import Path -from contextlib import contextmanager - -from altk.pre_tool.toolguard.toolguard.common.str import to_snake_case - - -def py_extension(filename: str) -> str: - return filename if filename.endswith(".py") else filename + ".py" - - -def un_py_extension(filename: str) -> str: - return filename[:-3] if filename.endswith(".py") else filename - - -def path_to_module(file_path: str) -> str: - assert file_path - parts = file_path.split("/") - if parts[-1].endswith(".py"): - parts[-1] = un_py_extension(parts[-1]) - return ".".join([to_snake_case(part) for part in parts]) - - -def module_to_path(module: str) -> str: - parts = module.split(".") - return os.path.join(*parts[:-1], py_extension(parts[-1])) - - -def unwrap_fn(fn: Callable) -> Callable: - return fn.func if hasattr(fn, "func") else fn - - 
-@contextmanager -def temp_python_path(path: str): - path = str(Path(path).resolve()) - if path not in sys.path: - sys.path.insert(0, path) - try: - yield - finally: - sys.path.remove(path) - else: - # Already in sys.path, no need to remove - yield - - -def extract_docstr_args(func: Callable) -> str: - doc = inspect.getdoc(func) - if not doc: - return "" - - lines = doc.splitlines() - args_start = None - for i, line in enumerate(lines): - if line.strip().lower() == "args:": - args_start = i - break - - if args_start is None: - return "" - - # List of known docstring section headers - next_sections = { - "returns:", - "raises:", - "examples:", - "notes:", - "attributes:", - "yields:", - } - - # Capture lines after "Args:" that are indented - args_lines = [] - for line in lines[args_start + 1 :]: - # Stop if we hit a new section (like "Returns:", "Raises:", etc.) - stripped = line.strip().lower() - if stripped in next_sections: - break - args_lines.append(" " * 8 + line.strip()) - - # Join all lines into a single string - if not args_lines: - return "" - - return "\n".join(args_lines) diff --git a/altk/pre_tool/toolguard/toolguard/common/py_doc_str.py b/altk/pre_tool/toolguard/toolguard/common/py_doc_str.py deleted file mode 100644 index 26c6188..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/py_doc_str.py +++ /dev/null @@ -1,70 +0,0 @@ -import inspect -import re -from typing import Callable - - -def extract_docstr_args(func: Callable) -> str: - doc = inspect.getdoc(func) - if not doc: - return "" - - lines = doc.splitlines() - - def args_start_line(): - for i, line in enumerate(lines): - if line.strip().lower() == "args:": # Google style docstr - return i + 1 - if ( - line.strip().lower().startswith(":param ") - ): # Sphinx-style docstring. https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html - return i - - args_start = args_start_line() - if args_start is None: - return "" - - # List of known docstring section headers - next_sections = { - "returns:", - "raises:", - "examples:", - "notes:", - "attributes:", - "yields:", - } - - # Capture lines after "Args:" that are indented - args_lines = [] - indent = " " * 4 * 2 - for line in lines[args_start:]: - # Stop if we hit a new section (like "Returns:", "Raises:", etc.) - stripped = line.strip().lower() - if stripped in next_sections or stripped.startswith(":return:"): - break - - args_lines.append(indent + sphinx_param_to_google(line.strip())) - - # Join all lines into a single string - if not args_lines: - return "" - - return "\n".join(args_lines) - - -def sphinx_param_to_google(line: str) -> str: - """ - Convert a single Sphinx-style ':param' line to Google style. - - Args: - line: A Sphinx param line, e.g. - ':param user_id: The unique identifier of the user.' - - Returns: - str: Google style equivalent, e.g. - 'user_id: The unique identifier of the user.' 
- """ - m = re.match(r"\s*:param\s+(\w+)\s*:\s*(.*)", line) - if not m: - return line - name, desc = m.groups() - return f"{name}: {desc}" diff --git a/altk/pre_tool/toolguard/toolguard/common/ref.py b/altk/pre_tool/toolguard/toolguard/common/ref.py deleted file mode 100644 index 00308e6..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/ref.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import TypeVar -from pydantic import BaseModel, Field - -from altk.pre_tool.toolguard.toolguard.common.dict import find_ref - - -class Reference(BaseModel): - ref: str = Field(..., alias="$ref") - - -BaseModelT = TypeVar("BaseModelT", bound=BaseModel) - - -class DocumentWithRef(BaseModel): - def resolve_ref( - self, obj: Reference | BaseModelT | None, object_type: type[BaseModelT] - ) -> BaseModelT | None: - if isinstance(obj, Reference): - tmp = find_ref(self.model_dump(), obj.ref) - return object_type.model_validate(tmp) - return obj diff --git a/altk/pre_tool/toolguard/toolguard/common/str.py b/altk/pre_tool/toolguard/toolguard/common/str.py deleted file mode 100644 index d9a8a9b..0000000 --- a/altk/pre_tool/toolguard/toolguard/common/str.py +++ /dev/null @@ -1,21 +0,0 @@ -def to_camel_case(snake_str: str) -> str: - return ( - snake_str.replace("_", " ") - .title() - .replace(" ", "") - .replace("-", "_") - .replace("'", "_") - .replace(",", "_") - .replace(".", "_") - ) - - -def to_snake_case(human_name: str) -> str: - return ( - human_name.lower() - .replace(" ", "_") - .replace("-", "_") - .replace("'", "_") - .replace(",", "_") - .replace(".", "_") - ) diff --git a/altk/pre_tool/toolguard/toolguard/core.py b/altk/pre_tool/toolguard/toolguard/core.py deleted file mode 100644 index ebf168a..0000000 --- a/altk/pre_tool/toolguard/toolguard/core.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -from os.path import join -from typing import Callable, List, Optional - -import json -import logging - -from altk.pre_tool.toolguard.toolguard.gen_py.gen_toolguards import ( - generate_toolguards_from_openapi, - generate_toolguards_from_functions, -) -from altk.pre_tool.toolguard.toolguard.llm.i_tg_llm import I_TG_LLM -from altk.pre_tool.toolguard.toolguard.runtime import ToolGuardsCodeGenerationResult -from altk.pre_tool.toolguard.toolguard.data_types import ToolPolicy, load_tool_policy -from altk.pre_tool.toolguard.toolguard.tool_policy_extractor.create_oas_summary import ( - OASSummarizer, -) -from altk.pre_tool.toolguard.toolguard.tool_policy_extractor.text_tool_policy_generator import ( - ToolInfo, - extract_policies, -) - - -logger = logging.getLogger(__name__) - - -async def build_toolguards( - policy_text: str, - tools: List[Callable] | str, - step1_out_dir: str, - step2_out_dir: str, - step1_llm: I_TG_LLM, - app_name: str = "my_app", - tools2run: List[str] | None = None, - short1=True, -) -> ToolGuardsCodeGenerationResult: - if isinstance(tools, list): # supports list of functions or list of langgraph tools - tools_info = [ToolInfo.from_function(tool) for tool in tools] - tool_policies = await extract_policies( - policy_text, tools_info, step1_out_dir, step1_llm, tools2run, short1 - ) - return await generate_guards_from_tool_policies( - tools, tool_policies, step2_out_dir, app_name, None, tools2run - ) - - if isinstance(tools, str): # Backward compatibility to support OpenAPI specs - oas_path = tools - with open(oas_path, "r", encoding="utf-8") as file: - oas = json.load(file) - summarizer = OASSummarizer(oas) - tools_info = summarizer.summarize() - tool_policies = await extract_policies( - policy_text, 
tools_info, step1_out_dir, step1_llm, tools2run, short1 - ) - return await generate_guards_from_tool_policies_oas( - oas_path, tool_policies, step2_out_dir, app_name, tools2run - ) - - raise ValueError("Unknown tools") - - -async def generate_guards_from_tool_policies( - funcs: List[Callable], - tool_policies: List[ToolPolicy], - to_step2_path: str, - app_name: str, - lib_names: Optional[List[str]] = None, - tool_names: Optional[List[str]] = None, -) -> ToolGuardsCodeGenerationResult: - os.makedirs(to_step2_path, exist_ok=True) - - tool_policies = [ - policy - for policy in tool_policies - if (not tool_names) or (policy.tool_name in tool_names) - ] - return await generate_toolguards_from_functions( - app_name, tool_policies, to_step2_path, funcs=funcs, module_roots=lib_names - ) - - -async def generate_guards_from_tool_policies_oas( - oas_path: str, - tool_policies: List[ToolPolicy], - to_step2_path: str, - app_name: str, - tool_names: Optional[List[str]] = None, -) -> ToolGuardsCodeGenerationResult: - os.makedirs(to_step2_path, exist_ok=True) - - tool_policies = [ - policy - for policy in tool_policies - if (not tool_names) or (policy.tool_name in tool_names) - ] - return await generate_toolguards_from_openapi( - app_name, tool_policies, to_step2_path, oas_path - ) - - -def load_policies_in_folder( - folder: str, -) -> List[ToolPolicy]: - files = [ - f - for f in os.listdir(folder) - if os.path.isfile(join(folder, f)) and f.endswith(".json") - ] - tool_policies = [] - for file in files: - tool_name = file[: -len(".json")] - policy = load_tool_policy(join(folder, file), tool_name) - if policy.policy_items: - tool_policies.append(policy) - return tool_policies diff --git a/altk/pre_tool/toolguard/toolguard/data_types.py b/altk/pre_tool/toolguard/toolguard/data_types.py deleted file mode 100644 index ff37a0d..0000000 --- a/altk/pre_tool/toolguard/toolguard/data_types.py +++ /dev/null @@ -1,127 +0,0 @@ -import json -import os -from pathlib import Path -from pydantic import BaseModel, Field -from typing import List, Optional - -DEBUG_DIR = "debug" -TESTS_DIR = "tests" -RESULTS_FILENAME = "result.json" -API_PARAM = "api" - - -class FileTwin(BaseModel): - file_name: str - content: str - - def save(self, folder: str) -> "FileTwin": - full_path = os.path.join(folder, self.file_name) - parent = Path(full_path).parent - os.makedirs(parent, exist_ok=True) - with open(full_path, "w") as file: - file.write(self.content) - return self - - def save_as(self, folder: str, file_name: str) -> "FileTwin": - file_path = os.path.join(folder, file_name) - with open(file_path, "w") as file: - file.write(self.content) - return FileTwin(file_name=file_name, content=self.content) - - @staticmethod - def load_from(folder: str, file_path: str) -> "FileTwin": - with open(os.path.join(folder, file_path), "r") as file: - data = file.read() - return FileTwin(file_name=file_path, content=data) - - -class ToolPolicyItem(BaseModel): - name: str = Field(..., description="Policy item name") - description: str = Field(..., description="Policy item description") - references: List[str] = Field(..., description="Original source texts") - compliance_examples: Optional[List[str]] = Field( - ..., description="Examples of cases that comply with the policy" - ) - violation_examples: Optional[List[str]] = Field( - ..., description="Examples of cases that violate the policy" - ) - - def to_md_bullets(self, items: List[str]) -> str: - s = "" - for item in items: - s += f"* {item}\n" - return s - - def __str__(self) -> str: - s = "#### Policy item " + self.name + "\n" - s += f"{self.description}\n" - if self.compliance_examples: - s += f"##### Positive examples\n{self.to_md_bullets(self.compliance_examples)}" - if self.violation_examples: - s += f"##### Negative examples\n{self.to_md_bullets(self.violation_examples)}" - return s - - -class ToolPolicy(BaseModel): - tool_name: str = Field(..., description="Name of the tool") - policy_items: List[ToolPolicyItem] = Field( - ..., - description="Policy items. All policy items (AND logic) must hold when invoking the tool.", - ) - - -def load_tool_policy(file_path: str, tool_name: str) -> ToolPolicy: - with open(file_path, "r") as file: - d = json.load(file) - - items = [ - ToolPolicyItem( - name=item.get("policy_name"), - description=item.get("description"), - references=item.get("references"), - compliance_examples=item.get("compliance_examples"), - violation_examples=item.get("violating_examples"), - ) - for item in d.get("policies", []) - if not item.get("skip") - ] - return ToolPolicy(tool_name=tool_name, policy_items=items) - - -class Domain(BaseModel): - app_name: str = Field(..., description="Application name") - toolguard_common: FileTwin = Field( - ..., description="Pydantic data types used by the toolguard framework." - ) - app_types: FileTwin = Field( - ..., description="Data types used in the application API as payloads." - ) - app_api_class_name: str = Field(..., description="Name of the API class.") - app_api: FileTwin = Field( - ..., description="Python class (abstract) containing all the API signatures." - ) - app_api_size: int = Field(..., description="Number of functions in the API") - - -class RuntimeDomain(Domain): - app_api_impl_class_name: str = Field( - ..., description="Name of the Python implementation class." - ) - app_api_impl: FileTwin = Field( - ..., description="Python class containing all the API method implementations."
- ) - - def get_definitions_only(self): - return Domain.model_validate(self.model_dump()) - - -class PolicyViolationException(Exception): - _msg: str - - def __init__(self, message: str): - super().__init__(message) - self._msg = message - - @property - def message(self): - return self._msg diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/__init__.py b/altk/pre_tool/toolguard/toolguard/gen_py/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/api_extractor.py b/altk/pre_tool/toolguard/toolguard/gen_py/api_extractor.py deleted file mode 100644 index 5feb9a4..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/api_extractor.py +++ /dev/null @@ -1,784 +0,0 @@ -from dataclasses import is_dataclass -from enum import Enum -import inspect -import os -import textwrap -from types import FunctionType, UnionType -import types -from typing import ( - Callable, - DefaultDict, - Dict, - List, - Literal, - Optional, - Set, - Tuple, - get_type_hints, - get_origin, - get_args, -) -from typing import Annotated, Union -from collections import defaultdict, deque -import typing -from altk.pre_tool.toolguard.toolguard.common.py import module_to_path, unwrap_fn -from altk.pre_tool.toolguard.toolguard.data_types import FileTwin - -Dependencies = DefaultDict[type, Set[type]] - - -class APIExtractor: - def __init__(self, py_path: str, include_module_roots: Optional[List[str]] = None): - if not include_module_roots: - include_module_roots = [] - self.py_path = py_path - self.include_module_roots = include_module_roots - - def extract_from_functions( - self, - funcs: List[Callable], - interface_name: str, - interface_module_name: str, - types_module_name: str, - impl_module_name: str, - impl_class_name: str, - ) -> Tuple[FileTwin, FileTwin, FileTwin]: - funcs = [unwrap_fn(func) for func in funcs] - assert all([_is_global_or_class_function(func) for func in funcs]) - - os.makedirs(self.py_path, exist_ok=True) - - # used types - types = FileTwin( - file_name=module_to_path(types_module_name), - content=self._generate_types_file( - *self._collect_all_types_from_functions(funcs) - ), - ).save(self.py_path) - - # API interface - interface = FileTwin( - file_name=module_to_path(interface_module_name), - content=self._generate_interface_from_functions( - funcs, interface_name, types_module_name - ), - ).save(self.py_path) - - # API impl interface - impl = FileTwin( - file_name=module_to_path(impl_module_name), - content=self._generate_impl_from_functions( - funcs, - impl_class_name, - interface_module_name, - interface_name, - types_module_name, - ), - ).save(self.py_path) - - return interface, types, impl - - def extract_from_class( - self, - typ: type, - *, - interface_name: Optional[str] = None, - interface_module_name: Optional[str] = None, - types_module_name: Optional[str] = None, - ) -> Tuple[FileTwin, FileTwin]: - """Extract interface and types from a class and save to files.""" - class_name = _get_type_name(typ) - interface_name = interface_name or "I_" + class_name - interface_module_name = interface_module_name or f"I_{class_name}".lower() - types_module_name = types_module_name or f"{class_name}_types".lower() - - os.makedirs(self.py_path, exist_ok=True) - - # Types - collected, dependencies = self._collect_all_types_from_class(typ) - types_content = self._generate_types_file(collected, dependencies) - types_file = module_to_path(types_module_name) - types = FileTwin(file_name=types_file, content=types_content).save(self.py_path) - - # API 
interface - if_content = self._generate_interface_from_class( - typ, interface_name, types_module_name - ) - if_file = module_to_path(interface_module_name) - interface = FileTwin(file_name=if_file, content=if_content).save(self.py_path) - - return interface, types - - def _generate_interface_from_class( - self, typ: type, interface_name: str, types_module: str - ) -> str: - # Start building the interface - lines = [ - "# Auto-generated class interface", - "from typing import * # type: ignore", - "from abc import ABC, abstractmethod", - f"from {types_module} import *", - "", - ] - - lines.append(f"class {interface_name}(ABC):") # Abstract class - - # Add class docstring if available - if typ.__doc__: - docstring = typ.__doc__.strip() - if docstring: - lines.append(' """') - # Handle multi-line docstrings - for line in docstring.split("\n"): - lines.append(f" {line.strip()}") - lines.append(' """') - - # Get all methods - methods = [] - for name, method in inspect.getmembers(typ, predicate=inspect.isfunction): - if not name.startswith("_"): - methods.append((name, method)) - - if not methods: - lines.append(" pass") - else: - for method_name, method in methods: - # Add method docstring and signature - lines.append(" @abstractmethod") - method_lines = self._get_function_with_docstring(method, method_name) - lines.extend([line if line else "" for line in method_lines]) - lines.append(" ...") - lines.append("") - - return textwrap.dedent("\n".join(lines)) - - def _generate_interface_from_functions( - self, funcs: List[Callable], interface_name: str, types_module: str - ) -> str: - lines = [ - "# Auto-generated class interface", - "from typing import * # type: ignore", - "from abc import ABC, abstractmethod", - f"from {types_module} import *", - "", - ] - - lines.append(f"class {interface_name}(ABC):") # Abstract class - lines.append("") - - indent = " " * 4 - if not funcs: - lines.append(f"{indent}pass") - else: - for func in funcs: - # Add method docstring and signature - lines.append(f"{indent}@abstractmethod") - method_lines = self._get_function_with_docstring( - func, _get_type_name(func) - ) - lines.extend([line if line else "" for line in method_lines]) - lines.append(f"{indent}{indent}...") - lines.append("") - - if any(["Decimal" in line for line in lines]): - lines.insert(2, "from decimal import Decimal") - - return "\n".join(lines) - - def _generate_impl_from_functions( - self, - funcs: List[Callable], - class_name: str, - interface_module_name: str, - interface_name: str, - types_module: str, - ) -> str: - lines = [ - "# Auto-generated class", - "from typing import *", - "from abc import ABC, abstractmethod", - f"from {interface_module_name} import {interface_name}", - f"from {types_module} import *", - "", - """class IToolInvoker(ABC): - T = TypeVar("T") - @abstractmethod - def invoke(self, toolname: str, arguments: Dict[str, Any], model: Type[T])->T: - ...""", - "", - ] - - lines.append(f"class {class_name}({interface_name}):") # class - lines.append("") - lines.append(""" def __init__(self, delegate: IToolInvoker): - self._delegate = delegate - """) - - if not funcs: - lines.append(" pass") - else: - for func in funcs: - # Add method docstring and signature - method_lines = self._get_function_with_docstring( - func, _get_type_name(func) - ) - lines.extend([line if line else "" for line in method_lines]) - lines.extend(self._generate_delegate_code(func)) - lines.append("") - - if any(["Decimal" in line for line in lines]): - lines.insert(2, "from decimal import Decimal") - - 
return "\n".join(lines) - - def _generate_delegate_code(self, func: Callable) -> List[str]: - func_name = _get_type_name(func) - indent = " " * 4 * 2 - sig = inspect.signature(func) - ret = sig.return_annotation - if ret is inspect._empty: - ret_name = "None" - elif hasattr(ret, "__name__"): - ret_name = ret.__name__ - else: - ret_name = str(ret) - return [ - indent + "args = {k: v for k, v in locals().items() if k != 'self'}", - indent + f"return self._delegate.invoke('{func_name}', args, {ret_name})", - ] - - def _get_function_with_docstring( - self, func: FunctionType, func_name: str - ) -> List[str]: - """Extract method signature with type hints and docstring.""" - lines = [] - - # Get method signature - method_signature = self._get_method_signature(func, func_name) - lines.append(f" {method_signature}:") - - # Add method docstring if available - if func.__doc__: - docstring = func.__doc__ - indent = " " * 8 - if docstring: - lines.append(indent + '"""') - lines.extend(docstring.strip("\n").split("\n")) - lines.append(indent + '"""') - - return lines - - def should_include_type(self, typ: type) -> bool: - if hasattr(typ, "__module__"): - module_root = typ.__module__.split(".")[0] - if module_root in self.include_module_roots: - return True - return any([self.should_include_type(arg) for arg in get_args(typ)]) - - def _generate_class_definition(self, typ: type) -> List[str]: - """Generate a class definition with its fields.""" - lines = [] - class_name = _get_type_name(typ) - - if is_dataclass(typ): - lines.append("@dataclass") - - # Determine base classes - bases = [_get_type_name(b) for b in _get_type_bases(typ)] - inheritance = f"({', '.join(bases)})" if bases else "" - lines.append(f"class {class_name}{inheritance}:") - - # #is Pydantic? - # is_pydantic = False - # for base in cls.__bases__: - # if hasattr(base, '__module__') and 'pydantic' in str(base.__module__): - # is_pydantic = True - - indent = " " * 4 - # Add class docstring if available - if typ.__doc__: - docstring = typ.__doc__ - if docstring: - lines.append(f'{indent}"""') - lines.extend( - [f"{indent}{line}" for line in docstring.strip("\n").split("\n")] - ) - lines.append(f'{indent}"""') - - # Fields - annotations = getattr(typ, "__annotations__", {}) - if annotations: - field_descriptions = self._extract_field_descriptions(typ) - for field_name, field_type in annotations.items(): - if field_name.startswith("_"): - continue - - # Handle optional field detection by default=None - is_optional = False - default_val = getattr(typ, field_name, ...) 
- if default_val is None: - is_optional = True - elif hasattr(typ, "__fields__"): - # Pydantic field with default=None - field_info = typ.__fields__.get(field_name) - if field_info and field_info.is_required() is False: - is_optional = True - - type_str = self._format_type(field_type) - - # Avoid wrapping Optional twice - if is_optional: - origin = get_origin(field_type) - args = get_args(field_type) - already_optional = ( - origin is typing.Union - and type(None) in args - or type_str.startswith("Optional[") - ) - if not already_optional: - type_str = f"Optional[{type_str}]" - - # Check if we have a description for this field - description = field_descriptions.get(field_name) - - # if description and is_pydantic: - # # Use Pydantic Field with description - # lines.append(f' {field_name}: {type_str} = Field(description="{description}")') - if description: - # Add description as comment for non-Pydantic classes - lines.append(f"{indent}{field_name}: {type_str} # {description}") - else: - # No description available - lines.append(f"{indent}{field_name}: {type_str}") - - # Enum - elif issubclass(typ, Enum): - if issubclass(typ, str): - lines.extend( - [f'{indent}{entry.name} = "{entry.value}"' for entry in typ] - ) - else: - lines.extend([f"{indent}{entry.name} = {entry.value}" for entry in typ]) - - else: - lines.append(f"{indent}pass") - - return lines - - def _extract_field_descriptions(self, typ: type) -> Dict[str, str]: - """Extract field descriptions from various sources.""" - descriptions = {} - - # Method 1: Check for Pydantic Field definitions - if hasattr(typ, "__fields__"): # Pydantic v1 - for field_name, field_info in typ.__fields__.items(): - if hasattr(field_info, "field_info") and hasattr( - field_info.field_info, "description" - ): - descriptions[field_name] = field_info.field_info.description - elif hasattr(field_info, "description") and field_info.description: - descriptions[field_name] = field_info.description - - # Method 2: Check for Pydantic v2 model fields - if hasattr(typ, "model_fields"): # Pydantic v2 - for field_name, field_info in typ.model_fields.items(): - if hasattr(field_info, "description") and field_info.description: - descriptions[field_name] = field_info.description - - # Method 3: Check class attributes for Field() definitions - for attr_name in dir(typ): - if not attr_name.startswith("_"): - try: - attr_value = getattr(typ, attr_name) - # Check if it's a Pydantic Field - if hasattr(attr_value, "description") and attr_value.description: - descriptions[attr_name] = attr_value.description - elif hasattr(attr_value, "field_info") and hasattr( - attr_value.field_info, "description" - ): - descriptions[attr_name] = attr_value.field_info.description - except Exception: - pass - - # Method 4: Parse class source for inline comments or docstrings - try: - source_lines = inspect.getsourcelines(typ)[0] - current_field = None - - for line in source_lines: - line = line.strip() - - # Look for field definitions with type hints - if ( - ":" in line - and not line.startswith("def ") - and not line.startswith("class ") - ): - # Extract field name - field_part = line.split(":")[0].strip() - if " " not in field_part and field_part.isidentifier(): - current_field = field_part - - # Look for comments on the same line or next line - if "#" in line and current_field: - comment = line.split("#", 1)[1].strip() - if comment and current_field not in descriptions: - descriptions[current_field] = comment - current_field = None - - except Exception: - pass - - # Method 5: Check for 
dataclass field descriptions - if hasattr(typ, "__dataclass_fields__"): - for field_name, field in typ.__dataclass_fields__.items(): - if hasattr(field, "metadata") and "description" in field.metadata: - descriptions[field_name] = field.metadata["description"] - - return descriptions - - def _get_method_signature(self, method: FunctionType, method_name: str): - """Extract method signature with type hints.""" - try: - sig = inspect.signature(method) - # Get param hints - try: - param_hints = get_type_hints(method) - except Exception: - param_hints = {} - - params = [] - if not sig.parameters.get("self"): - params.append("self") - - for param_name, param in sig.parameters.items(): - param_str = param_name - - # Add type annotation if available - if param_name in param_hints: - type_str = self._format_type(param_hints[param_name]) - param_str += f": {type_str}" - elif param.annotation != param.empty: - param_str += f": {param.annotation}" - - # Add default value if present - if param.default != param.empty: - if isinstance(param.default, str): - param_str += f' = "{param.default}"' - else: - param_str += f" = {repr(param.default)}" - - params.append(param_str) - - # Handle return type - return_annotation = "" - if "return" in param_hints: - if param_hints["return"] is not type(None): - return_type = self._format_type(param_hints["return"]) - return_annotation = f" -> {return_type}" - elif sig.return_annotation != sig.empty: - return_annotation = f" -> {sig.return_annotation}" - - params_str = ", ".join(params) - return f"def {method_name}({params_str}){return_annotation}" - - except Exception: - # Fallback for problematic signatures - return f"def {method_name}(self, *args, **kwargs)" - - def _collect_all_types_from_functions( - self, funcs: List[Callable] - ) -> Tuple[Set[type], Dependencies]: - processed_types = set() - collected = set() - dependencies = defaultdict(set) - - for func in funcs: - for param, hint in get_type_hints(func).items(): # noqa: B007 - self._collect_types_recursive( - hint, processed_types, collected, dependencies - ) - - return collected, dependencies - - def _collect_all_types_from_class( - self, typ: type - ) -> Tuple[Set[type], Dependencies]: - """Collect all types used in the class recursively.""" - visited = set() - collected = set() - dependencies = defaultdict(set) - - # Field types - try: - class_hints = get_type_hints(typ) - for field, hint in class_hints.items(): # noqa: B007 - self._collect_types_recursive(hint, visited, collected, dependencies) - except Exception: - pass - - # Methods and param types - for name, method in inspect.getmembers(typ, predicate=inspect.isfunction): # noqa: B007 - try: - method_hints = get_type_hints(method) - for hint in method_hints.values(): - self._collect_types_recursive( - hint, visited, collected, dependencies - ) - except Exception: - pass - - # Also collect base class types - for base in _get_type_bases(typ): - self._collect_types_recursive(base, visited, collected, dependencies) - - return collected, dependencies - - def _collect_types_recursive( - self, typ: type, visited: Set[type], acc: Set[type], dependencies: Dependencies - ): - """Recursively collect all types from a type hint.""" - visited.add(typ) - - if not self.should_include_type(typ): - return - - acc.add(typ) - origin = get_origin(typ) - args = get_args(typ) - - # Type with generic arguments. 
eg: List[Person] - if origin and args: - for f_arg in args: - self._collect_types_recursive(f_arg, visited, acc, dependencies) - self._add_dependency(typ, f_arg, dependencies) - return - - # If it's a custom class, try to get its type hints - try: - field_hints = typ.__annotations__ # direct fields - for field_name, field_hint in field_hints.items(): # noqa: B007 - f_origin = get_origin(field_hint) - if f_origin: - for f_arg in get_args(field_hint): - self._collect_types_recursive(f_arg, visited, acc, dependencies) - self._add_dependency(typ, f_arg, dependencies) - else: - self._collect_types_recursive( - field_hint, visited, acc, dependencies - ) - self._add_dependency(typ, field_hint, dependencies) - - for base in _get_type_bases(typ): # Base classes - self._collect_types_recursive(base, visited, acc, dependencies) - self._add_dependency(typ, base, dependencies) - except Exception: - pass - - def _add_dependency( - self, dependent_type: type, dependency_type: type, dependencies: Dependencies - ): - """Add a dependency relationship between types.""" - dep_name = _get_type_name(dependent_type) - dep_on_name = _get_type_name(dependency_type) - if dep_name != dep_on_name: - dependencies[dependent_type].add(dependency_type) - - for arg in get_args(dependency_type): - dependencies[dependent_type].add(arg) - - def _topological_sort_types(self, types: List[type], dependencies: Dependencies): - """Sort types by their dependencies using topological sort.""" - # Create a mapping of type names to types for easier lookup - type_map = {_get_type_name(t): t for t in types} - - # Build adjacency list and in-degree count - adj_list = defaultdict(list) - in_degree = defaultdict(int) - - # Initialize in-degree for all types - for t in types: - type_name = _get_type_name(t) - if type_name not in in_degree: - in_degree[type_name] = 0 - - # Build the dependency graph - for dependent_type in types: - dependent_name = _get_type_name(dependent_type) - for dependency_type in dependencies.get(dependent_type, set()): - dependency_name = _get_type_name(dependency_type) - if ( - dependency_name in type_map - ): # Only consider types we're actually processing - adj_list[dependency_name].append(dependent_name) - in_degree[dependent_name] += 1 - - # Kahn's algorithm for topological sorting - queue = deque([name for name in in_degree if in_degree[name] == 0]) - result = [] - - while queue: - current = queue.popleft() - result.append(type_map[current]) - - for neighbor in adj_list[current]: - in_degree[neighbor] -= 1 - if in_degree[neighbor] == 0: - queue.append(neighbor) - - # If we couldn't sort all types, there might be circular dependencies - # Add remaining types at the end - sorted_names = {_get_type_name(t) for t in result} - remaining = [t for t in types if _get_type_name(t) not in sorted_names] - result.extend(remaining) - - return result - - def _generate_types_file( - self, collected_types: Set[type], dependencies: Dependencies - ) -> str: - """Generate the types file content.""" - lines = [] - lines.append("# Auto-generated type definitions") - lines.append("from datetime import date, datetime") - lines.append("from enum import Enum") - lines.append("from typing import *") - lines.append("from pydantic import BaseModel, Field") - lines.append("from dataclasses import dataclass") - lines.append("") - - custom_classes = [] - for typ in collected_types: - # Check if it's a class with attributes - if hasattr(typ, "__annotations__") or ( - hasattr(typ, "__dict__") - and any( - not callable(getattr(typ, attr, None)) 
- for attr in dir(typ) - if not attr.startswith("_") - ) - ): - custom_classes.append(typ) - custom_classes = self._topological_sort_types(custom_classes, dependencies) - - # Generate custom classes (sorted by dependency) - for cls in custom_classes: - class_def = self._generate_class_definition(cls) - if class_def: # Only add non-empty class definitions - lines.extend(class_def) - lines.append("") - - if any(["Decimal" in line for line in lines]): - lines.insert(2, "from decimal import Decimal") - - return "\n".join(lines) - - def _format_type(self, typ: type) -> str: - if typ is None: - return "Any" - - # Unwrap Annotated[T, ...] - origin = get_origin(typ) - if origin is Annotated: - typ = get_args(typ)[0] - origin = get_origin(typ) - - # Literal - if origin is Literal: - args = get_args(typ) - literals = ", ".join(repr(a) for a in args) - return f"Literal[{literals}]" - - # Union (Optional or other) - if origin is Union: - args = get_args(typ) - non_none = [a for a in args if a is not type(None)] - if len(non_none) == 1: - return f"Optional[{self._format_type(non_none[0])}]" - else: - inner = ", ".join(self._format_type(a) for a in args) - return f"Union[{inner}]" - - if origin is UnionType: - args = get_args(typ) - return " | ".join(self._format_type(a) for a in args) - - # Generic containers - if origin: - args = get_args(typ) - inner = ", ".join(self._format_type(a) for a in args) - if inner: - return f"{_get_type_name(origin)}[{inner}]" - return _get_type_name(origin) - - # Simple type - return _get_type_name(origin or typ) - - -def _get_type_name(typ) -> str: - """Get a consistent name for a type object.""" - if hasattr(typ, "__name__"): - return typ.__name__ - return str(typ) - - -def _get_type_bases(typ: type) -> List[type]: - if hasattr(typ, "__bases__"): - return typ.__bases__ # type: ignore - return [] - - -def _is_global_or_class_function(func): - if not callable(func): - return False - - # Reject lambdas - if _get_type_name(func) == "<lambda>": - return False - - # Static methods and global functions are of type FunctionType - if isinstance(func, types.FunctionType): - return True - - # Class methods are MethodType but have __self__ as the class, not instance - if isinstance(func, types.MethodType): - if inspect.isclass(func.__self__): - return True # classmethod - else: - return False # instance method - - return False - - -# Example class for testing -if __name__ == "__main__": - # from airline2.domains.airline.tools import AirlineTools - # extractor = APIExtractor("output", include_module_roots = ["airline2"]) - # # interface_file, types_file = extractor.extract_from_class(AirlineTools, output_dir="output") - # t = AirlineTools({}) - # interface_file, types_file = extractor.extract_from_functions([AirlineTools.book_reservation], "I_Tau", "my_app.i_tau", "my_app.tau_types") - - from appointment_app import lg_tools - - funcs = [ - lg_tools.add_user, - lg_tools.add_payment_method, - lg_tools.get_user_payment_methods, - lg_tools.get_available_dr_specialties, - lg_tools.search_doctors, - lg_tools.search_available_appointments, - ] - extractor = APIExtractor("output", include_module_roots=["appointment_app"]) - interface, types_, impl = extractor.extract_from_functions( - funcs, - "I_Clinic", - "clinic.i_clinic", - "clinic.clinic_types", - "clinic.clinic_impl", - "ClinicImpl", - ) - - print(f"Interface saved to: {interface.file_name}") - print(f"Types saved to: {types_.file_name}") - print(f"Impl saved to: {impl.file_name}") - print("Done") diff --git
a/altk/pre_tool/toolguard/toolguard/gen_py/consts.py b/altk/pre_tool/toolguard/toolguard/gen_py/consts.py deleted file mode 100644 index fa14d69..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/consts.py +++ /dev/null @@ -1,35 +0,0 @@ -from altk.pre_tool.toolguard.toolguard.common.str import to_snake_case -from altk.pre_tool.toolguard.toolguard.data_types import ToolPolicy, ToolPolicyItem - - -RUNTIME_PACKAGE_NAME = "rt_toolguard" -RUNTIME_INIT_PY = "__init__.py" -RUNTIME_TYPES_PY = "data_types.py" -RUNTIME_APP_TYPES_PY = "domain_types.py" - -PY_ENV = "my_env" -PY_PACKAGES = ["pydantic", "pytest"] # , "litellm"] - - -def guard_fn_name(tool_policy: ToolPolicy) -> str: - return to_snake_case(f"guard_{tool_policy.tool_name}") - - -def guard_fn_module_name(tool_policy: ToolPolicy) -> str: - return to_snake_case(f"guard_{tool_policy.tool_name}") - - -def guard_item_fn_name(tool_item: ToolPolicyItem) -> str: - return to_snake_case(f"guard_{tool_item.name}") - - -def guard_item_fn_module_name(tool_item: ToolPolicyItem) -> str: - return to_snake_case(f"guard_{tool_item.name}") - - -def test_fn_name(tool_item: ToolPolicyItem) -> str: - return to_snake_case(f"test_guard_{tool_item.name}") - - -def test_fn_module_name(tool_item: ToolPolicyItem) -> str: - return to_snake_case(f"test_guard_{tool_item.name}") diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/domain_from_funcs.py b/altk/pre_tool/toolguard/toolguard/gen_py/domain_from_funcs.py deleted file mode 100644 index f815108..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/domain_from_funcs.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -from pathlib import Path -from typing import Callable, List -from os.path import join - -from altk.pre_tool.toolguard.toolguard.gen_py.api_extractor import APIExtractor -from altk.pre_tool.toolguard.toolguard.common.str import to_camel_case, to_snake_case -from altk.pre_tool.toolguard.toolguard.data_types import FileTwin, RuntimeDomain -import altk.pre_tool.toolguard.toolguard.gen_py.consts as consts - - -def generate_domain_from_functions( - py_path: str, app_name: str, funcs: List[Callable], include_module_roots: List[str] -) -> RuntimeDomain: - # ToolGuard Runtime - os.makedirs(join(py_path, consts.RUNTIME_PACKAGE_NAME), exist_ok=True) - - root = str(Path(__file__).parent.parent) - common = FileTwin.load_from(root, "data_types.py").save_as( - py_path, join(consts.RUNTIME_PACKAGE_NAME, consts.RUNTIME_TYPES_PY) - ) - runtime = FileTwin.load_from(root, "runtime.py") - runtime.content = runtime.content.replace( - "toolguard.", f"{consts.RUNTIME_PACKAGE_NAME}." 
- ) - runtime.save_as(py_path, join(consts.RUNTIME_PACKAGE_NAME, consts.RUNTIME_INIT_PY)) - - # APP init and Types - os.makedirs(join(py_path, to_snake_case(app_name)), exist_ok=True) - FileTwin(file_name=join(to_snake_case(app_name), "__init__.py"), content="").save( - py_path - ) - - extractor = APIExtractor(py_path=py_path, include_module_roots=include_module_roots) - api_cls_name = f"I_{to_camel_case(app_name)}" - impl_module_name = to_snake_case(f"{app_name}.{app_name}_impl") - impl_class_name = to_camel_case(f"{app_name}_Impl") - api, types, impl = extractor.extract_from_functions( - funcs, - interface_name=api_cls_name, - interface_module_name=to_snake_case(f"{app_name}.i_{app_name}"), - types_module_name=to_snake_case(f"{app_name}.{app_name}_types"), - impl_module_name=impl_module_name, - impl_class_name=impl_class_name, - ) - - return RuntimeDomain( - app_name=app_name, - toolguard_common=common, - app_types=types, - app_api_class_name=api_cls_name, - app_api=api, - app_api_impl_class_name=impl_class_name, - app_api_impl=impl, - app_api_size=len(funcs), - ) diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/domain_from_openapi.py b/altk/pre_tool/toolguard/toolguard/gen_py/domain_from_openapi.py deleted file mode 100644 index f19e930..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/domain_from_openapi.py +++ /dev/null @@ -1,241 +0,0 @@ -import os -from pathlib import Path -from typing import List, Optional, Tuple, Union -from os.path import join - -from altk.pre_tool.toolguard.toolguard.common.array import find -from altk.pre_tool.toolguard.toolguard.common.py import module_to_path -from altk.pre_tool.toolguard.toolguard.common.str import to_camel_case, to_snake_case -import altk.pre_tool.toolguard.toolguard.gen_py.consts as consts -from altk.pre_tool.toolguard.toolguard.gen_py.templates import load_template -from altk.pre_tool.toolguard.toolguard.gen_py.utils.datamodel_codegen import ( - run as dm_codegen, -) -from altk.pre_tool.toolguard.toolguard.common.open_api import ( - OpenAPI, - Operation, - Parameter, - ParameterIn, - PathItem, - Reference, - RequestBody, - Response, - JSchema, - read_openapi, -) -from altk.pre_tool.toolguard.toolguard.data_types import FileTwin, RuntimeDomain - - -def generate_domain_from_openapi( - py_path: str, app_name: str, openapi_file: str -) -> RuntimeDomain: - # ToolGuard Runtime - os.makedirs(join(py_path, consts.RUNTIME_PACKAGE_NAME), exist_ok=True) - - root = str(Path(__file__).parent.parent) - common = FileTwin.load_from(root, "data_types.py").save_as( - py_path, join(consts.RUNTIME_PACKAGE_NAME, consts.RUNTIME_TYPES_PY) - ) - runtime = FileTwin.load_from(root, "runtime.py") - runtime.content = runtime.content.replace( - "toolguard.", f"{consts.RUNTIME_PACKAGE_NAME}." - ) - runtime.save_as(py_path, join(consts.RUNTIME_PACKAGE_NAME, consts.RUNTIME_INIT_PY)) - - # APP Types - oas = read_openapi(openapi_file) - os.makedirs(join(py_path, to_snake_case(app_name)), exist_ok=True) - - types_name = f"{app_name}_types" - types_module_name = f"{app_name}.{types_name}" - types = FileTwin( - file_name=module_to_path(types_module_name), content=dm_codegen(openapi_file) - ).save(py_path) - - # APP Init - FileTwin( - file_name=join(to_snake_case(app_name), "__init__.py"), - content=f"from . 
import {types_name}", - ).save(py_path) - - # APP API - api_cls_name = to_camel_case("I " + app_name) - methods = _get_oas_methods(oas) - api_module_name = to_snake_case(f"{app_name}.i_{app_name}") - api = FileTwin( - file_name=module_to_path(api_module_name), - content=_generate_api(methods, api_cls_name, types_module_name), - ).save(py_path) - - # APP API Impl - impl_cls_name = to_camel_case(app_name + " impl") - impl_module_name = to_snake_case(f"{app_name}.{app_name}_impl") - cls_str = _generate_api_impl( - methods, api_module_name, types_module_name, api_cls_name, impl_cls_name - ) - api_impl = FileTwin( - file_name=module_to_path(impl_module_name), content=cls_str - ).save(py_path) - - return RuntimeDomain( - app_name=app_name, - toolguard_common=common, - app_types=types, - app_api_class_name=api_cls_name, - app_api=api, - app_api_impl_class_name=impl_cls_name, - app_api_impl=api_impl, - app_api_size=len(methods), - ) - - -def _get_oas_methods(oas: OpenAPI): - methods = [] - for path, path_item in oas.paths.items(): # noqa: B007 - path_item = oas.resolve_ref(path_item, PathItem) - assert path_item - for mtd, op in path_item.operations.items(): # noqa: B007 - op = oas.resolve_ref(op, Operation) - if not op: - continue - params = (path_item.parameters or []) + (op.parameters or []) - params = [oas.resolve_ref(p, Parameter) for p in params] - args, ret = _make_signature(op, params, oas) # type: ignore - args_str = ", ".join(["self"] + [f"{arg}:{type}" for arg, type in args]) - sig = f"({args_str})->{ret}" - - body = "pass" - # if orign_funcs: - # func = find(orign_funcs or [], lambda fn: fn.__name__ == op.operationId) # type: ignore - # if func: - # body = _call_fn_body(func) - methods.append( - { - "name": to_snake_case(op.operationId), # type: ignore - "signature": sig, - "doc": op.description, - "body": body, - } - ) - return methods - - -# def _call_fn_body(func:Callable): -# module = inspect.getmodule(func) -# if module is None or not hasattr(module, '__file__'): -# raise ValueError("Function must be from an importable module") - -# module_name = module.__name__ -# qualname = func.__qualname__ -# func_name = func.__name__ -# parts = qualname.split('.') - -# if len(parts) == 1: # Regular function -# return f""" -# mod = importlib.import_module("{module_name}") -# func = getattr(mod, "{func_name}") -# return func(locals())""" - -# if len(parts) == 2: # Classmethod or staticmethod -# class_name = parts[0] -# return f""" -# mod = importlib.import_module("{module_name}") -# cls = getattr(mod, "{class_name}") -# func = getattr(cls, "{func_name}") -# return func(locals())""" - -# if len(parts) > 2: # Instance method -# class_name = parts[-2] -# return f""" -# mod = importlib.import_module("{module_name}") -# cls = getattr(mod, "{class_name}") -# instance = cls() -# func = getattr(instance, "{func_name}") -# return func(locals())""" -# raise NotImplementedError("Unsupported function type or nested depth") - - -def _generate_api(methods: List, cls_name: str, types_module: str) -> str: - return load_template("api.j2").render( - types_module=types_module, class_name=cls_name, methods=methods - ) - - -def _generate_api_impl( - methods: List, api_module: str, types_module: str, api_cls_name: str, cls_name: str -) -> str: - return load_template("api_impl.j2").render( - api_cls_name=api_cls_name, - types_module=types_module, - api_module=api_module, - class_name=cls_name, - methods=methods, - ) - - -def _make_signature( - op: Operation, params: List[Parameter], oas: OpenAPI -) -> 
Tuple[List[Tuple[str, str]], str]: - fn_name = to_camel_case(op.operationId) - args = [] - - for param in params: - if param.in_ == ParameterIn.path: - args.append((param.name, _oas_to_py_type(param.schema_, oas) or "Any")) - - if find(params, lambda p: p.in_ == ParameterIn.query): - query_type = f"{fn_name}ParametersQuery" - args.append(("args", query_type)) - - req_body = oas.resolve_ref(op.requestBody, RequestBody) - if req_body: - scm_or_ref = req_body.content_json.schema_ - body_type = _oas_to_py_type(scm_or_ref, oas) - if body_type is None: - body_type = f"{fn_name}Request" - args.append(("args", body_type)) - - rsp_type = "None" # default when no 200 response schema is declared - rsp_or_ref = op.responses.get("200") - rsp = oas.resolve_ref(rsp_or_ref, Response) - if rsp: - scm_or_ref = rsp.content_json.schema_ - if scm_or_ref: - rsp_type = _oas_to_py_type(scm_or_ref, oas) - if rsp_type is None: - rsp_type = f"{fn_name}Response" - - return args, rsp_type - - -def _oas_to_py_type(scm_or_ref: Union[Reference, JSchema], oas: OpenAPI) -> str | None: - if isinstance(scm_or_ref, Reference): - return scm_or_ref.ref.split("/")[-1] - - scm = oas.resolve_ref(scm_or_ref, JSchema) - if scm: - py_type = _primitive_jschema_types_to_py(scm.type, scm.format) - if py_type: - return py_type - # if scm.type == JSONSchemaTypes.array and scm.items: - # return f"List[{oas_to_py_type(scm.items, oas) or 'Any'}]" - - -def _primitive_jschema_types_to_py( - type: Optional[str], format: Optional[str] -) -> Optional[str]: - # https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#data-types - if type == "string": - if format == "date": - return "datetime.date" - if format == "date-time": - return "datetime.datetime" - if format in ["byte", "binary"]: - return "bytes" - return "str" - if type == "integer": - return "int" - if type == "number": - return "float" - if type == "boolean": - return "bool" - return None diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/gen_toolguards.py b/altk/pre_tool/toolguard/toolguard/gen_py/gen_toolguards.py deleted file mode 100644 index 921d65f..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/gen_toolguards.py +++ /dev/null @@ -1,123 +0,0 @@ -import asyncio -import json -import logging -import os -from os.path import join -from typing import Callable, List, Literal, Optional, cast - -import mellea - -import altk.pre_tool.toolguard.toolguard.gen_py.consts as consts -from altk.pre_tool.toolguard.toolguard.gen_py.domain_from_funcs import ( - generate_domain_from_functions, -) -from altk.pre_tool.toolguard.toolguard.data_types import RuntimeDomain, ToolPolicy -from altk.pre_tool.toolguard.toolguard.gen_py.domain_from_openapi import ( - generate_domain_from_openapi, -) -from altk.pre_tool.toolguard.toolguard.runtime import ToolGuardsCodeGenerationResult -from altk.pre_tool.toolguard.toolguard.gen_py.tool_guard_generator import ( - ToolGuardGenerator, -) -import altk.pre_tool.toolguard.toolguard.gen_py.utils.pytest as pytest -import altk.pre_tool.toolguard.toolguard.gen_py.utils.venv as venv -import altk.pre_tool.toolguard.toolguard.gen_py.utils.pyright as pyright -from altk.pre_tool.toolguard.toolguard.common.py import unwrap_fn -
-logger = logging.getLogger(__name__) - -ENV_GENPY_BACKEND_NAME = "TOOLGUARD_GENPY_BACKEND_NAME" -ENV_GENPY_MODEL_ID = "TOOLGUARD_GENPY_MODEL_ID" -ENV_GENPY_ARGS = "TOOLGUARD_GENPY_ARGS" - - -async def generate_toolguards_from_functions( - app_name: str, - tool_policies: List[ToolPolicy], - py_root: str, - funcs: List[Callable], - module_roots: Optional[List[str]] = None, -) -> ToolGuardsCodeGenerationResult: - assert funcs, "Funcs cannot be empty" - logger.debug(f"Starting... will save into {py_root}") - - if not module_roots: - if len(funcs) > 0: - module_roots = list( - {unwrap_fn(func).__module__.split(".")[0] for func in funcs} - ) - assert module_roots - - # Domain from functions - domain = generate_domain_from_functions(py_root, app_name, funcs, module_roots) - return await generate_toolguards_from_domain( - app_name, tool_policies, py_root, domain - ) - - -async def generate_toolguards_from_openapi( - app_name: str, tool_policies: List[ToolPolicy], py_root: str, openapi_file: str -) -> ToolGuardsCodeGenerationResult: - logger.debug(f"Starting... will save into {py_root}") - - # Domain from OpenAPI - domain = generate_domain_from_openapi(py_root, app_name, openapi_file) - return await generate_toolguards_from_domain( - app_name, tool_policies, py_root, domain - ) - - -def start_melea_session() -> mellea.MelleaSession: - backend_name = cast( - Literal["ollama", "hf", "openai", "watsonx", "litellm"], - os.getenv(ENV_GENPY_BACKEND_NAME, "openai"), - ) - - model_id = os.getenv(ENV_GENPY_MODEL_ID) - assert model_id, f"'{ENV_GENPY_MODEL_ID}' environment variable not set" - - kw_args_val = os.getenv(ENV_GENPY_ARGS) - kw_args = {} - if kw_args_val: - try: - kw_args = json.loads(kw_args_val) - except Exception as e: - logger.warning( - f"Failed to parse {ENV_GENPY_ARGS}: {e}. Using empty dict instead." - ) - - logger.debug(f"mellea session kwargs: {kw_args}") - - return mellea.start_session(backend_name=backend_name, model_id=model_id, **kw_args) - - -async def generate_toolguards_from_domain( - app_name: str, tool_policies: List[ToolPolicy], py_root: str, domain: RuntimeDomain -) -> ToolGuardsCodeGenerationResult: - # Setup env - venv.run(join(py_root, consts.PY_ENV), consts.PY_PACKAGES) - pyright.config(py_root) - pytest.configure(py_root) - - with start_melea_session(): - # tools that actually have policy items - tools_with_policies = [ - tool_policy - for tool_policy in tool_policies - if len(tool_policy.policy_items) > 0 - ] - tool_results = await asyncio.gather( - *[ - ToolGuardGenerator( - app_name, tool_policy, py_root, domain, consts.PY_ENV - ).generate() - for tool_policy in tools_with_policies - ] - ) - - tools_result = { - tool.tool_name: res for tool, res in zip(tools_with_policies, tool_results) - } - return ToolGuardsCodeGenerationResult( - root_dir=py_root, domain=domain, tools=tools_result - ).save(py_root) diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/__init__.py b/altk/pre_tool/toolguard/toolguard/gen_py/prompts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/gen_tests.py b/altk/pre_tool/toolguard/toolguard/gen_py/prompts/gen_tests.py deleted file mode 100644 index 80a1724..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/gen_tests.py +++ /dev/null @@ -1,139 +0,0 @@ -from typing import List -from altk.pre_tool.toolguard.toolguard.data_types import ( - Domain, - FileTwin, - ToolPolicyItem, -) -from mellea import generative - -# from toolguard.gen_py.prompts.python_code import PythonCodeModel - - -@generative -def generate_init_tests( - fn_src: FileTwin, - policy_item: ToolPolicyItem, - domain: Domain, - dependent_tool_names: List[str], -) -> str: - """ - Generate Python unit tests for a function to verify tool-call compliance with policy constraints. - - Args: - fn_src (FileTwin): Source code containing the function-under-test signature.
- policy_item (ToolPolicyItem): Specification of the function-under-test, including positive and negative examples. - domain (Domain): Available data types and interfaces that the test can use. - dependent_tool_names (List[str]): Other tool names that this tool depends on. - - Returns: - str: Generated Python unit test code. - - This function creates unit tests to validate the behavior of a given function-under-test. - The function-under-test checks the argument data and raises an exception if it violates the requirements in the policy item. - - Test Generation Rules: - - Make sure to import all items from the fn_src, common, and domain modules. - - A `policy_item` has multiple `compliance_examples` and `violation_examples`. - - For each of the `compliance_examples`, ONE test method is generated. - - For each of the `violation_examples`, ONE test method is generated. - - For violation examples, the function-under-test is EXPECTED to raise a `PolicyViolationException`. - - Use `with pytest.raises(PolicyViolationException): function_under_test()` to assert that the exception is raised. - - - Test class and method names should be meaningful and use up to **six words in snake_case**. - - For each test, add a comment quoting the policy item case that this function is testing. - - The failure message should describe the test scenario that failed, plus the expected and the actual outcomes. - - Data population and references: - - For compliance examples, populate all fields. - - For collections (arrays, dicts, and sets) populate at least one item. - - You should mock the return_value of ALL tools listed in `dependent_tool_names`. - - Use `side_effect` to return the expected value only when the expected parameters are passed. - - For time-dependent attributes, compute the timestamp dynamically (avoid hardcoded times). - - For example, to set a timestamp that occurred 24 hours ago, use something like: `created_at = (datetime.now() - timedelta(hours=24)).strftime("%Y-%m-%dT%H:%M:%S")`. - - Import the required date and time libraries, for example: `from datetime import datetime, timedelta` - - If you have a choice between passing a Pydantic model or a plain dictionary, prefer the Pydantic model. - - Example: - * fn_src: - ```python - # file: my_app/guard_create_reservation.py - def guard_create_reservation(api: SomeAPI, user_id: str, hotel_id: str, reservation_date: str, persons: int): - ... - ``` - * policy_item.description = "cannot book a room for a date in the past" - * policy_item.violation_examples = ["book a room for a hotel, one week ago"] - * dependent_tool_names: `["get_user", "get_hotel"]` - * Domain: - ```python - # file: my_app/api.py - class SomeAPI(ABC): - def get_user(self, user_id): - ... - def get_hotel(self, hotel_id): - ... - def create_reservation(self, user_id: str, hotel_id: str, reservation_date: str, persons: int): - \"\"\" - Args: - ... - reservation_date: check-in date, in `YYYY-MM-DDTHH:MM:SS` format - \"\"\" - ... - ``` - - Should return this snippet: - ```python - from unittest.mock import MagicMock - import pytest - from toolguard.data_types import PolicyViolationException - from my_app.guard_create_reservation import guard_create_reservation - from my_app.api import * - - def test_violation_book_room_in_the_past(): - \"\"\" - Policy: "cannot book room for a date in the past" - Example: "book a room for a hotel, one week ago" - \"\"\" - - # mock other tools' function return values - user = User(user_id="123", ...) - hotel = Hotel(hotel_id="789", ...)
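- # (User and Hotel are hypothetical domain models; their remaining fields are elided with "..." in this example.)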
- - api = MagicMock(spec=SomeAPI) - api.get_user.side_effect = lambda user_id: user if user_id == "123" else None - api.get_hotel.side_effect = lambda hotel_id: hotel if hotel_id == "789" else None - - #invoke function under test. - with pytest.raises(PolicyViolationException): - next_week = (datetime.now() + timedelta(days=7)).strftime("%Y-%m-%dT%H:%M:%S") - guard_create_reservation(api, user_id="123", hotel_id="789", reservation_date=next_week, persons=3) - ``` - """ - ... - - -@generative -def improve_tests( - prev_impl: str, - domain: Domain, - policy_item: ToolPolicyItem, - review_comments: List[str], - dependent_tool_names: List[str], -) -> str: - """ - Improve the previous test functions (in Python) to check the given tool policy-items according to the review-comments. - - Args: - prev_impl (str): previous implementation of a Python function. - domain (Domain): Python source code defining available data types and APIs that the test can use. - tool (ToolPolicyItem): Requirements for this tool. - review_comments (List[str]): Review comments on the current implementation. For example, pylint errors or Failed unit-tests. - dependent_tool_names (List[str]): other tool names that this tool depends on. - - Returns: - str: Improved implementation pytest test functions. - - Implementation Rules: - - Do not change the function signatures. - - You can add import statements, but dont remove them. - """ - ... diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/improve_guard.py b/altk/pre_tool/toolguard/toolguard/gen_py/prompts/improve_guard.py deleted file mode 100644 index 7af2f88..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/improve_guard.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import List -from altk.pre_tool.toolguard.toolguard.data_types import Domain, ToolPolicyItem -from mellea import generative - -# from toolguard.gen_py.prompts.python_code import PythonCodeModel - - -@generative -def improve_tool_guard( - prev_impl: str, - domain: Domain, - policy_item: ToolPolicyItem, - dependent_tool_names: List[str], - review_comments: List[str], -) -> str: - """ - Improve the previous tool-call guard implementation (in Python) so that it fully adheres to the given policy and addresses all review comments. - - Args: - prev_impl (str): The previous implementation of the tool-call check. - domain (Domain): Python code defining available data types and other tool interfaces. - policy_item (ToolPolicyItem): Requirements for this tool. - dependent_tool_names (List[str]): Names of other tools that this tool may call to obtain required information. - review_comments (List[str]): Review feedback on the current implementation (e.g., pylint errors, failed unit tests). - - Returns: - str: The improved implementation of the tool-call check. - - Implementation Rules: - - Do not modify the function signature, parameter names, or type annotations. - - All policy requirements must be validated. - - Keep the implementation simple and well-documented. - - Only validate the tool-call arguments; never call the tool itself. - - If additional information is needed beyond the function arguments, use only the APIs of tools listed in `dependent_tool_names`. - - Generate code that enforces the given policy only, do not generate any additional logic that is not explicitly mentioned in the policy. 
-
-    **Example:**
-    prev_impl = ```python
-    from typing import *
-    from airline.airline_types import *
-    from airline.i_airline import I_Airline
-
-    def guard_Checked_Bag_Allowance_by_Membership_Tier(api: I_Airline, user_id: str, passengers: list[Passenger]):
-        \"\"\"
-        Limit to five passengers per reservation.
-        \"\"\"
-        pass #FIXME
-    ```
-
-    should return something like:
-    ```python
-    from typing import *
-    from airline.airline_types import *
-    from airline.i_airline import I_Airline
-
-    def guard_Checked_Bag_Allowance_by_Membership_Tier(api: I_Airline, user_id: str, passengers: list[Passenger]):
-        \"\"\"
-        Limit to five passengers per reservation.
-        \"\"\"
-        if len(passengers) > 5:
-            raise PolicyViolationException("More than five passengers are not allowed.")
-    ```
-    """
-    ...
diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/pseudo_code.py b/altk/pre_tool/toolguard/toolguard/gen_py/prompts/pseudo_code.py
deleted file mode 100644
index 02f98b1..0000000
--- a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/pseudo_code.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from altk.pre_tool.toolguard.toolguard.data_types import Domain, ToolPolicyItem
-from mellea import generative
-
-
-@generative
-def tool_policy_pseudo_code(
-    policy_item: ToolPolicyItem, fn_to_analyze: str, domain: Domain
-) -> str:
-    """
-    Returns pseudo code that checks business constraints on a tool call using an API.
-
-    Args:
-        policy_item (ToolPolicyItem): Business policy, in natural language, specifying a constraint on a process involving the tool under analysis.
-        fn_to_analyze (str): The function signature of the tool under analysis.
-        domain (Domain): Python code defining available data types and APIs for invoking other tools.
-
-    Returns:
-        str: Pseudo code describing how to use the API to check the tool call.
-
-    * The available API functions are listed in `domain.app_api.content`.
-    * Analyze the API functions' signatures (input and output parameter types).
-    * You cannot assume other API functions.
-    * For data objects (dataclasses or Pydantic models), only reference the explicitly declared fields.
-    * Do not assume the presence of any additional fields.
-    * Do not assume any implicit logic or relationships between field values (e.g., naming conventions).
-    * List all the required API calls to check the business constraints.
-    * If some information is missing, you should call another API function declared in the domain API.
-
-    Examples:
-    ```python
-    domain = {
-        "app_types": {
-            "file_name": "car_types.py",
-            "content": '''
-                class CarType(Enum):
-                    SEDAN = "sedan"
-                    SUV = "suv"
-                    VAN = "van"
-                class Car:
-                    plate_num: str
-                    car_type: CarType
-                class Person:
-                    id: str
-                    driving_licence: str
-                class Insurance:
-                    doc_id: str
-                class CarOwnership:
-                    owner_id: str
-                    start_date: str
-                    end_date: str
-            '''
-        },
-        "app_api": {
-            "file_name": "cars_api.py",
-            "content": '''
-                class CarAPI(ABC):
-                    def buy_car(self, plate_num: str, owner_id: str, insurance_id: str): pass
-                    def get_person(self, id: str) -> Person: pass
-                    def get_insurance(self, id: str) -> Insurance: pass
-                    def get_car(self, plate_num: str) -> Car: pass
-                    def car_ownership_history(self, plate_num: str) -> List[CarOwnership]: pass
-                    def delete_car(self, plate_num: str): pass
-                    def list_all_cars_owned_by(self, id: str) -> List[Car]: pass
-                    def are_relatives(self, person1_id: str, person2_id: str) -> bool: pass
-            '''
-        }
-    }
-    ```
-    * Example 1:
-      ```
-      tool_policy_pseudo_code(
-          {"name": "documents exist", "description": "when buying a car, check that the car owner has a driving licence and that the insurance is valid."},
-          "buy_car(plate_num: str, owner_id: str, insurance_id: str)",
-          domain
-      )
-      ```
-      may return:
-      ```
-      assert api.get_person(owner_id).driving_licence
-      assert api.get_insurance(insurance_id)
-      ```
-
-    * Example 2:
-      ```
-      tool_policy_pseudo_code(
-          {"name": "has driving licence", "description": "when buying a car, check that the car owner has a driving licence"},
-          "buy_car(plate_num: str, owner_id: str, insurance_id: str)",
-          domain
-      )
-      ```
-      may return:
-      ```
-      assert api.get_person(owner_id).driving_licence
-      ```
-
-    * Example 3:
-      ```
-      tool_policy_pseudo_code(
-          {"name": "no transfers on holidays", "description": "when buying a car, check that it is not a holiday today"},
-          "buy_car(plate_num: str, owner_id: str, insurance_id: str)",
-          domain
-      )
-      ```
-      should return an empty string (the API offers no way to check holidays).
-
-    * Example 4:
-      ```
-      tool_policy_pseudo_code(
-          {"name": "Not in the same family",
-           "description": "when buying a van, check that the van was never owned by someone from the buyer's family."},
-          "buy_car(plate_num: str, owner_id: str, insurance_id: str)",
-          domain
-      )
-      ```
-      should return:
-      ```
-      car = api.get_car(plate_num)
-      if car.car_type == CarType.VAN:
-          history = api.car_ownership_history(plate_num)
-          for each ownership in history:
-              assert(not api.are_relatives(ownership.owner_id, owner_id))
-      ```
-    """
-    ...
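The pseudo code produced by this prompt is consumed by `tool_dependencies.py` (deleted further below), which scans it for `api.<name>(...)` calls to decide which other tools a guard depends on. A minimal sketch of that extraction step, reusing the regex from the deleted `_extract_api_calls` helper; the standalone `extract_api_calls` name and the sample pseudo code (taken from Example 4 above) are illustrative only:

```python
import re
from typing import Set

def extract_api_calls(pseudo_code: str) -> Set[str]:
    # Collect the names of all `api.<fn>(...)` invocations mentioned in the pseudo code.
    return set(re.findall(r"\bapi\.(\w+)\s*\(", pseudo_code))

pseudo = """
car = api.get_car(plate_num)
if car.car_type == CarType.VAN:
    history = api.car_ownership_history(plate_num)
    for each ownership in history:
        assert(not api.are_relatives(ownership.owner_id, owner_id))
"""

print(sorted(extract_api_calls(pseudo)))
# ['are_relatives', 'car_ownership_history', 'get_car']
```

The generator then keeps only the names that actually appear in `domain.app_api.content`, retrying the prompt a few times when the model hallucinates an API function.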
diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/~python_code.py b/altk/pre_tool/toolguard/toolguard/gen_py/prompts/~python_code.py deleted file mode 100644 index f3feba2..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/prompts/~python_code.py +++ /dev/null @@ -1,24 +0,0 @@ -from pydantic import BaseModel, Field -import re - -PYTHON_PATTERN = r"^```python\s*\n([\s\S]*)\n```" - - -class PythonCodeModel(BaseModel): - python_code: str = Field( - ..., - ) - - def get_code_content(self) -> str: - code = self.python_code.replace("\\n", "\n") - match = re.match(PYTHON_PATTERN, code) - if match: - return match.group(1) - - return code - - @classmethod - def create(cls, python_code: str) -> "PythonCodeModel": - return PythonCodeModel.model_construct( - python_code=f"```python\n{python_code}\n```" - ) diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/templates/__init__.py b/altk/pre_tool/toolguard/toolguard/gen_py/templates/__init__.py deleted file mode 100644 index e26b90d..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/templates/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from jinja2 import Environment, PackageLoader, select_autoescape - -from altk.pre_tool.toolguard.toolguard.common.py import path_to_module -from altk.pre_tool.toolguard.toolguard.common.str import to_snake_case - -env = Environment( - loader=PackageLoader("altk.pre_tool.toolguard.toolguard.gen_py", "templates"), - autoescape=select_autoescape(), -) -env.globals["path_to_module"] = path_to_module -env.globals["to_snake_case"] = to_snake_case - - -def load_template(template_name: str): - return env.get_template(template_name) diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/templates/api.j2 b/altk/pre_tool/toolguard/toolguard/gen_py/templates/api.j2 deleted file mode 100644 index ce22327..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/templates/api.j2 +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Any, Dict, List -from abc import ABC, abstractmethod -from datetime import date, datetime -from {{ types_module }} import * - -class {{ class_name }}(ABC): -{% for method in methods %} - @abstractmethod - def {{ method.name }}{{ method.signature }}: - {% if method.doc %}"""{{ method.doc }}"""{% endif %} - pass -{% endfor %} diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/templates/api_impl.j2 b/altk/pre_tool/toolguard/toolguard/gen_py/templates/api_impl.j2 deleted file mode 100644 index ab114b8..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/templates/api_impl.j2 +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any, Dict, List -from datetime import date, datetime -from {{ types_module }} import * -from {{ api_module }} import * - -class {{ class_name }}({{ api_cls_name }}): -{% for method in methods %} - def {{ method.name }}{{ method.signature }}: - {% if method.doc %}"""{{ method.doc }}"""{% endif %} - {{ method.body }} -{% endfor %} diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/templates/tool_guard.j2 b/altk/pre_tool/toolguard/toolguard/gen_py/templates/tool_guard.j2 deleted file mode 100644 index d4b0e35..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/templates/tool_guard.j2 +++ /dev/null @@ -1,25 +0,0 @@ -from typing import * -{% for imp in extra_imports %}{{ imp }} -{% endfor %} -import {{to_snake_case(domain.app_name)}} -from {{path_to_module(domain.toolguard_common.file_name)}} import PolicyViolationException -from {{ path_to_module(domain.app_types.file_name) }} import * -from {{ path_to_module(domain.app_api.file_name) }} import {{ 
domain.app_api_class_name }}
-
-{% for item in items %}from {{ path_to_module(item.file_name) }} import {{ item.guard_fn }}
-{% endfor %}
-
-def {{ method.name }}(api: {{ domain.app_api_class_name }}, {{ method.signature}}):
-    """
-    Checks that a tool call complies with the policies.
-
-    Args:
-        api ({{ domain.app_api_class_name }}): API to access other tools.
-{{method.args_doc_str}}
-
-    Raises:
-        PolicyViolationException: If the tool call does not comply with the policy.
-    """
-
-{% for item in items %} {{ item.guard_fn }}(api, {{method.args_call}})
-{% endfor %}
diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/templates/tool_item_guard.j2 b/altk/pre_tool/toolguard/toolguard/gen_py/templates/tool_item_guard.j2
deleted file mode 100644
index f8deed1..0000000
--- a/altk/pre_tool/toolguard/toolguard/gen_py/templates/tool_item_guard.j2
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import *
-{% for imp in extra_imports %}{{ imp }}
-{% endfor %}
-import {{to_snake_case(domain.app_name)}}
-from {{path_to_module(domain.toolguard_common.file_name)}} import PolicyViolationException
-from {{ path_to_module(domain.app_types.file_name) }} import *
-from {{ path_to_module(domain.app_api.file_name) }} import {{ domain.app_api_class_name }}
-
-def {{ method.name }}(api: {{ domain.app_api_class_name }}, {{ method.signature }}):
-    """
-    Policy to check: {{ policy }}
-
-    Args:
-        api ({{ domain.app_api_class_name }}): API to access other tools.
-{{method.args_doc_str}}
-    """
-    pass #FIXME
diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/tool_dependencies.py b/altk/pre_tool/toolguard/toolguard/gen_py/tool_dependencies.py
deleted file mode 100644
index 5149083..0000000
--- a/altk/pre_tool/toolguard/toolguard/gen_py/tool_dependencies.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import asyncio
-import re
-from typing import Set
-from altk.pre_tool.toolguard.toolguard.data_types import Domain, ToolPolicyItem
-from mellea.backends.types import ModelOption
-from altk.pre_tool.toolguard.toolguard.gen_py.prompts.pseudo_code import (
-    tool_policy_pseudo_code,
-)
-
-MAX_TRIALS = 3
-
-
-async def tool_dependencies(
-    policy_item: ToolPolicyItem, tool_signature: str, domain: Domain, trial=0
-) -> Set[str]:
-    model_options = {ModelOption.TEMPERATURE: 0.8}
-    pseudo_code = await asyncio.to_thread(  # FIXME: drop to_thread once mellea supports async
-        lambda: tool_policy_pseudo_code(
-            policy_item=policy_item,
-            fn_to_analyze=tool_signature,
-            domain=domain,
-            model_options=model_options,
-        )  # type: ignore
-    )
-    fn_names = _extract_api_calls(pseudo_code)
-    if all([f"{fn_name}(" in domain.app_api.content for fn_name in fn_names]):
-        return fn_names
-    if trial <= MAX_TRIALS:
-        # tool_policy_pseudo_code samples with some temperature, so retry, hoping the next pseudo code is correct
-        return await tool_dependencies(policy_item, tool_signature, domain, trial + 1)
-    raise Exception("Failed to analyze API dependencies")
-
-
-def _extract_api_calls(code: str) -> Set[str]:
-    pattern = re.compile(r"\bapi\.(\w+)\s*\(")
-    return set(pattern.findall(code))
diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/tool_guard_generator.py b/altk/pre_tool/toolguard/toolguard/gen_py/tool_guard_generator.py
deleted file mode 100644
index 02e4e0f..0000000
--- a/altk/pre_tool/toolguard/toolguard/gen_py/tool_guard_generator.py
+++ /dev/null
@@ -1,446 +0,0 @@
-import inspect
-import os
-import asyncio
-import logging
-from os.path import join
-import re
-from typing import Callable, List, Tuple
-
-from altk.pre_tool.toolguard.toolguard.common import py
-from
altk.pre_tool.toolguard.toolguard.common.llm_py import get_code_content -from altk.pre_tool.toolguard.toolguard.common.py_doc_str import extract_docstr_args -from altk.pre_tool.toolguard.toolguard.common.str import to_snake_case -from altk.pre_tool.toolguard.toolguard.data_types import ( - DEBUG_DIR, - TESTS_DIR, - FileTwin, - RuntimeDomain, - ToolPolicy, - ToolPolicyItem, -) -from altk.pre_tool.toolguard.toolguard.gen_py.consts import ( - guard_fn_module_name, - guard_fn_name, - guard_item_fn_module_name, - guard_item_fn_name, - test_fn_module_name, -) -from altk.pre_tool.toolguard.toolguard.gen_py.tool_dependencies import tool_dependencies -from altk.pre_tool.toolguard.toolguard.runtime import ( - ToolGuardCodeResult, - find_class_in_module, - load_module_from_path, -) -import altk.pre_tool.toolguard.toolguard.gen_py.utils.pytest as pytest -import altk.pre_tool.toolguard.toolguard.gen_py.utils.pyright as pyright -from altk.pre_tool.toolguard.toolguard.gen_py.prompts.gen_tests import ( - generate_init_tests, - improve_tests, -) -from altk.pre_tool.toolguard.toolguard.gen_py.prompts.improve_guard import ( - improve_tool_guard, -) -from altk.pre_tool.toolguard.toolguard.gen_py.templates import load_template - -logger = logging.getLogger(__name__) - -MAX_TOOL_IMPROVEMENTS = 5 -MAX_TEST_GEN_TRIALS = 3 - - -class ToolGuardGenerator: - app_name: str - py_path: str - tool_policy: ToolPolicy - domain: RuntimeDomain - common: FileTwin - - def __init__( - self, - app_name: str, - tool_policy: ToolPolicy, - py_path: str, - domain: RuntimeDomain, - py_env: str, - ) -> None: - self.py_path = py_path - self.app_name = app_name - self.tool_policy = tool_policy - self.domain = domain - self.py_env = py_env - - def start(self): - app_path = join(self.py_path, to_snake_case(self.app_name)) - os.makedirs(app_path, exist_ok=True) - os.makedirs( - join(app_path, to_snake_case(self.tool_policy.tool_name)), exist_ok=True - ) - os.makedirs(join(self.py_path, to_snake_case(DEBUG_DIR)), exist_ok=True) - os.makedirs( - join( - self.py_path, - to_snake_case(DEBUG_DIR), - to_snake_case(self.tool_policy.tool_name), - ), - exist_ok=True, - ) - for item in self.tool_policy.policy_items: - os.makedirs( - join( - self.py_path, - to_snake_case(DEBUG_DIR), - to_snake_case(self.tool_policy.tool_name), - to_snake_case(item.name), - ), - exist_ok=True, - ) - os.makedirs(join(self.py_path, to_snake_case(TESTS_DIR)), exist_ok=True) - - async def generate(self) -> ToolGuardCodeResult: - self.start() - tool_guard, init_item_guards = self._create_initial_tool_guards() - - # Generate guards for all tool items - tests_and_guards = await asyncio.gather( - *[ - self._generate_item_tests_and_guard(item, item_guard) - for item, item_guard in zip( - self.tool_policy.policy_items, init_item_guards - ) - ] - ) - - item_tests, item_guards = zip(*tests_and_guards) - return ToolGuardCodeResult( - tool=self.tool_policy, - guard_fn_name=guard_fn_name(self.tool_policy), - guard_file=tool_guard, - item_guard_files=item_guards, - test_files=item_tests, - ) - - async def _generate_item_tests_and_guard( - self, item: ToolPolicyItem, init_guard: FileTwin - ) -> Tuple[FileTwin | None, FileTwin]: - # Dependencies of this tool - tool_fn_name = to_snake_case(self.tool_policy.tool_name) - tool_fn = self._find_api_function(tool_fn_name) - sig_str = f"{tool_fn_name}{str(inspect.signature(tool_fn))}" - dep_tools = [] - if self.domain.app_api_size > 1: - domain = self.domain.get_definitions_only() # remove runtime fields - dep_tools = list(await 
tool_dependencies(item, sig_str, domain))
-        logger.debug(f"Dependencies of '{item.name}': {dep_tools}")
-
-        # Generate tests
-        try:
-            guard_tests = await self._generate_tests(item, init_guard, dep_tools)
-        except Exception as ex:
-            logger.warning(f"Test generation failed for item '{item.name}': %s", str(ex))
-            try:
-                logger.warning("trying to generate the code without tests... %s", str(ex))
-                guard = await self._improve_guard(item, init_guard, [], dep_tools)
-                return None, guard
-            except Exception as ex:
-                logger.warning(
-                    "guard generation failed. returning initial guard: %s", str(ex)
-                )
-                return None, init_guard
-
-        # Tests generated, now generate guards
-        try:
-            guard = await self._improve_guard_green_loop(
-                item, init_guard, guard_tests, dep_tools
-            )
-            logger.debug(
-                f"tool item generated successfully '{item.name}'"
-            )  # 😄🎉 Happy path
-            return guard_tests, guard
-        except Exception as ex:
-            logger.warning(
-                "guard generation failed. returning initial guard: %s", str(ex)
-            )
-            return None, init_guard
-
-    # async def tool_dependencies(self, policy_item: ToolPolicyItem, tool_signature: str) -> Set[str]:
-    #     domain = self.domain.get_definitions_only()  # remove runtime fields
-    #     pseudo_code = await tool_policy_pseudo_code(policy_item, tool_signature, domain)
-    #     dep_tools = await extract_api_dependencies_from_pseudo_code(pseudo_code, domain)
-    #     return set(dep_tools)
-
-    async def _generate_tests(
-        self, item: ToolPolicyItem, guard: FileTwin, dep_tools: List[str]
-    ) -> FileTwin:
-        test_file_name = join(
-            TESTS_DIR, self.tool_policy.tool_name, f"{test_fn_module_name(item)}.py"
-        )
-        errors = []
-        test_file = None
-        trials = "a b c".split()
-        for trial_no in trials:
-            logger.debug(
-                f"Generating tests iteration '{trial_no}' for tool {self.tool_policy.tool_name} '{item.name}'."
-            )
-            domain = self.domain.get_definitions_only()  # remove runtime fields
-            first_time = trial_no == "a"
-            if first_time:
-                # FIXME: drop to_thread once mellea supports async
-                res = await asyncio.to_thread(
-                    lambda: generate_init_tests(
-                        fn_src=guard,
-                        policy_item=item,
-                        domain=domain,  # noqa: B023
-                        dependent_tool_names=dep_tools,
-                    )
-                )
-            else:
-                assert test_file
-                # FIXME: drop to_thread once mellea supports async
-                res = await asyncio.to_thread(
-                    lambda: improve_tests(
-                        prev_impl=test_file.content,  # noqa: B023
-                        domain=domain,  # noqa: B023
-                        policy_item=item,
-                        review_comments=errors,  # noqa: B023
-                        dependent_tool_names=dep_tools,
-                    )
-                )
-
-            test_file = FileTwin(
-                file_name=test_file_name, content=get_code_content(res)
-            ).save(self.py_path)
-            test_file.save_as(self.py_path, self.debug_dir(item, f"test_{trial_no}.py"))
-
-            syntax_report = pyright.run(self.py_path, test_file.file_name, self.py_env)
-            FileTwin(
-                file_name=self.debug_dir(item, f"test_{trial_no}_pyright.json"),
-                content=syntax_report.model_dump_json(indent=2),
-            ).save(self.py_path)
-
-            if syntax_report.summary.errorCount > 0:
-                logger.warning(
-                    f"{syntax_report.summary.errorCount} syntax errors in tests iteration '{trial_no}' in item '{item.name}'."
-                )
-                errors = syntax_report.list_error_messages(test_file.content)
-                continue
-
-            # syntax is ok, try to run the tests...
-            logger.debug(
-                f"Generated tests for tool '{self.tool_policy.tool_name}' '{item.name}' (trial='{trial_no}')"
-            )
-            report_file_name = self.debug_dir(item, f"test_{trial_no}_pytest.json")
-            pytest_report = pytest.run(
-                self.py_path, test_file.file_name, report_file_name
-            )
-            if (
-                pytest_report.all_tests_collected_successfully()
-                and pytest_report.non_empty_tests()
-            ):
-                return test_file
-            if not pytest_report.non_empty_tests():  # empty test set
-                errors = ["an empty set of generated unit tests is not allowed"]
-            else:
-                errors = pytest_report.list_errors()
-
-        raise Exception("Generated tests contain syntax errors")
-
-    async def _improve_guard_green_loop(
-        self,
-        item: ToolPolicyItem,
-        guard: FileTwin,
-        tests: FileTwin,
-        dep_tools: List[str],
-    ) -> FileTwin:
-        trial_no = 0
-        while trial_no < MAX_TOOL_IMPROVEMENTS:
-            pytest_report_file = self.debug_dir(item, f"guard_{trial_no}_pytest.json")
-            errors = pytest.run(
-                self.py_path, tests.file_name, pytest_report_file
-            ).list_errors()
-            if errors:
-                logger.debug(f"'{item.name}' guard function tests failed. Retrying...")
-
-                trial_no += 1
-                try:
-                    guard = await self._improve_guard(
-                        item, guard, errors, dep_tools, trial_no
-                    )
-                except Exception:
-                    continue  # probably a syntax error in the generated code. let's retry...
-            else:
-                logger.debug(
-                    f"'{item.name}' guard function generated successfully and is Green 😄🎉. "
-                )
-                return guard  # Green
-
-        raise Exception(
-            f"Failed {MAX_TOOL_IMPROVEMENTS} times to generate guard function for tool {to_snake_case(self.tool_policy.tool_name)} policy: {item.name}"
-        )
-
-    async def _improve_guard(
-        self,
-        item: ToolPolicyItem,
-        prev_guard: FileTwin,
-        review_comments: List[str],
-        dep_tools: List[str],
-        round: int = 0,
-    ) -> FileTwin:
-        module_name = guard_item_fn_module_name(item)
-        errors = []
-        trials = "a b c".split()
-        for trial in trials:
-            logger.debug(
-                f"Improving guard function '{module_name}'... (trial = {round}.{trial})"
-            )
-            domain = self.domain.get_definitions_only()  # omit runtime fields
-            prev_python = get_code_content(prev_guard.content)
-            # FIXME: drop to_thread once mellea supports async
-            res = await asyncio.to_thread(
-                lambda: improve_tool_guard(
-                    prev_impl=prev_python,  # noqa: B023
-                    domain=domain,  # noqa: B023
-                    policy_item=item,
-                    dependent_tool_names=dep_tools,
-                    review_comments=review_comments + errors,  # noqa: B023
-                )
-            )
-
-            guard = FileTwin(
-                file_name=prev_guard.file_name, content=get_code_content(res)
-            ).save(self.py_path)
-            guard.save_as(
-                self.py_path, self.debug_dir(item, f"guard_{round}_{trial}.py")
-            )
-
-            syntax_report = pyright.run(self.py_path, guard.file_name, self.py_env)
-            FileTwin(
-                file_name=self.debug_dir(item, f"guard_{round}_{trial}.pyright.json"),
-                content=syntax_report.model_dump_json(indent=2),
-            ).save(self.py_path)
-            logger.info(
-                f"Generated function {module_name} with {syntax_report.summary.errorCount} errors."
-            )
-
-            if syntax_report.summary.errorCount > 0:
-                # Syntax errors. retry...
-                errors = syntax_report.list_error_messages(guard.content)
-                continue
-
-            guard.save_as(self.py_path, self.debug_dir(item, f"guard_{round}_final.py"))
-            return (
-                guard  # Happy path: improved version of the guard with no syntax errors
-            )
-
-        # Failed to generate valid Python after all trials
-        raise Exception(f"Syntax errors while generating the guard for tool '{item.name}'.")
-
-    def _find_api_function(self, tool_fn_name: str):
-        with py.temp_python_path(self.py_path):
-            module = load_module_from_path(self.domain.app_api.file_name, self.py_path)
-            assert module, f"File not found {self.domain.app_api.file_name}"
-            cls = find_class_in_module(module, self.domain.app_api_class_name)
-            return getattr(cls, tool_fn_name)
-
-    def _create_initial_tool_guards(self) -> Tuple[FileTwin, List[FileTwin]]:
-        tool_fn_name = to_snake_case(self.tool_policy.tool_name)
-        tool_fn = self._find_api_function(tool_fn_name)
-        assert tool_fn, f"Function not found, {tool_fn_name}"
-
-        # __init__.py
-        path = join(to_snake_case(self.app_name), tool_fn_name, "__init__.py")
-        FileTwin(file_name=path, content="").save(self.py_path)
-
-        # item guards files
-        item_files = [
-            self._create_item_module(item, tool_fn)
-            for item in self.tool_policy.policy_items
-        ]
-        # tool guard file
-        tool_file = self._create_tool_module(tool_fn, item_files)
-
-        # Save to debug folder
-        for item_guard_fn, policy_item in zip(
-            item_files, self.tool_policy.policy_items
-        ):
-            item_guard_fn.save_as(self.py_path, self.debug_dir(policy_item, "g0.py"))
-
-        return (tool_file, item_files)
-
-    def _create_tool_module(
-        self, tool_fn: Callable, item_files: List[FileTwin]
-    ) -> FileTwin:
-        file_name = join(
-            to_snake_case(self.app_name),
-            to_snake_case(self.tool_policy.tool_name),
-            py.py_extension(guard_fn_module_name(self.tool_policy)),
-        )
-        items = [
-            {"guard_fn": guard_item_fn_name(item), "file_name": file.file_name}
-            for (item, file) in zip(self.tool_policy.policy_items, item_files)
-        ]
-        sig = inspect.signature(tool_fn)
-        sig_str = self._signature_str(sig)
-        args_call = ", ".join([p for p in sig.parameters if p != "self"])
-        args_doc_str = extract_docstr_args(tool_fn)
-        extra_imports = []
-        if "Decimal" in sig_str:
-            extra_imports.append("from decimal import Decimal")
-
-        return FileTwin(
-            file_name=file_name,
-            content=load_template("tool_guard.j2").render(
-                domain=self.domain,
-                method={
-                    "name": guard_fn_name(self.tool_policy),
-                    "signature": sig_str,
-                    "args_call": args_call,
-                    "args_doc_str": args_doc_str,
-                },
-                items=items,
-                extra_imports=extra_imports,
-            ),
-        ).save(self.py_path)
-
-    def _signature_str(self, sig: inspect.Signature):
-        sig_str = str(sig)
-        sig_str = sig_str[
-            sig_str.find("self,") + len("self,") : sig_str.rfind(")")
-        ].strip()
-        # Strip module prefixes like airline.airline_types.XXX → XXX
-        clean_sig_str = re.sub(r"\b(?:\w+\.)+(\w+)", r"\1", sig_str)
-        return clean_sig_str
-
-    def _create_item_module(
-        self, tool_item: ToolPolicyItem, tool_fn: Callable
-    ) -> FileTwin:
-        file_name = join(
-            to_snake_case(self.app_name),
-            to_snake_case(self.tool_policy.tool_name),
-            py.py_extension(guard_item_fn_module_name(tool_item)),
-        )
-        sig_str = self._signature_str(inspect.signature(tool_fn))
-        args_doc_str = extract_docstr_args(tool_fn)
-        extra_imports = []
-        if "Decimal" in sig_str:
-            extra_imports.append("from decimal import Decimal")
-        return FileTwin(
-            file_name=file_name,
-            content=load_template("tool_item_guard.j2").render(
-                domain=self.domain,
-                method={
-                    "name": guard_item_fn_name(tool_item),
-                    "signature": sig_str,
-                    "args_doc_str": args_doc_str,
-                },
-                policy=tool_item.description,
-                extra_imports=extra_imports,
-            ),
-        ).save(self.py_path)
-
-    def debug_dir(self, policy_item: ToolPolicyItem, dir:
str): - return join( - DEBUG_DIR, - to_snake_case(self.tool_policy.tool_name), - to_snake_case(policy_item.name), - dir, - ) diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/utils/__init__.py b/altk/pre_tool/toolguard/toolguard/gen_py/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/utils/datamodel_codegen.py b/altk/pre_tool/toolguard/toolguard/gen_py/utils/datamodel_codegen.py deleted file mode 100644 index 7d3d33e..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/utils/datamodel_codegen.py +++ /dev/null @@ -1,34 +0,0 @@ -import subprocess - - -def run(oas_file: str): - # see https://github.com/koxudaxi/datamodel-code-generator - res = subprocess.run( - [ - "datamodel-codegen", - "--use-field-description", - "--use-schema-description", - "--output-model-type", - "pydantic_v2.BaseModel", # "typing.TypedDict", - "--collapse-root-models", - # "--force-optional", - "--reuse-model", - "--enum-field-as-literal", - "all", - "--input-file-type", - "openapi", - "--use-operation-id-as-name", - "--openapi-scopes", - "paths", - "parameters", - "schemas", - "--input", - oas_file, - # "--output", domain_file - ], - capture_output=True, - text=True, - ) - if res.returncode != 0: - raise Exception(res.stderr) - return res.stdout diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/utils/pyright.py b/altk/pre_tool/toolguard/toolguard/gen_py/utils/pyright.py deleted file mode 100644 index fd92d81..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/utils/pyright.py +++ /dev/null @@ -1,109 +0,0 @@ -import json -import os -import subprocess -from pydantic import BaseModel -from typing import List, Optional - -from altk.pre_tool.toolguard.toolguard.data_types import FileTwin - -ERROR = "error" -WARNING = "warning" - - -class Position(BaseModel): - line: int - character: int - - -class Range(BaseModel): - start: Position - end: Position - - -class GeneralDiagnostic(BaseModel): - file: str - severity: str - message: str - range: Range - rule: Optional[str] = None - - -class Summary(BaseModel): - filesAnalyzed: int - errorCount: int - warningCount: int - informationCount: int - timeInSec: float - - -class DiagnosticsReport(BaseModel): - version: str - time: str - generalDiagnostics: List[GeneralDiagnostic] - summary: Summary - - def list_error_messages(self, file_content: str) -> List[str]: - msgs = set() - for d in self.generalDiagnostics: - if d.severity == ERROR: - msgs.add( - f"Syntax error: {d.message}. 
code block: '{get_text_by_range(file_content, d.range)}, '" - ) - return list(msgs) - - -def get_text_by_range(file_content: str, rng: Range) -> str: - lines = file_content.splitlines() - - if rng.start.line == rng.end.line: - # Single-line span - return lines[rng.start.line][rng.start.character : rng.end.character] - - # Multi-line span - selected_lines = [] - selected_lines.append( - lines[rng.start.line][rng.start.character :] - ) # First line, from start.character - for line_num in range(rng.start.line + 1, rng.end.line): - selected_lines.append(lines[line_num]) # Full middle lines - selected_lines.append( - lines[rng.end.line][: rng.end.character] - ) # Last line, up to end.character - - return "\n".join(selected_lines) - - -def run(folder: str, py_file: str, venv_name: str) -> DiagnosticsReport: - py_path = os.path.join(venv_name, "bin", "python3") - res = subprocess.run( - [ - "pyright", - # "--venv-path", venv_path, - "--pythonpath", - py_path, - "--outputjson", - py_file, - ], - cwd=folder, - capture_output=True, - text=True, - ) - # if res.returncode !=0: - # raise Exception(res.stderr) - - data = json.loads(res.stdout) - return DiagnosticsReport.model_validate(data) - - -def config(folder: str): - cfg = { - "typeCheckingMode": "basic", - "reportOptionalIterable": WARNING, - "reportArgumentType": WARNING, # "Object of type \"None\" cannot be used as iterable value", - "reportOptionalMemberAccess": WARNING, - "reportOptionalSubscript": WARNING, - "reportAttributeAccessIssue": ERROR, - } - FileTwin(file_name="pyrightconfig.json", content=json.dumps(cfg, indent=2)).save( - folder - ) diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/utils/pytest.py b/altk/pre_tool/toolguard/toolguard/gen_py/utils/pytest.py deleted file mode 100644 index 9113c19..0000000 --- a/altk/pre_tool/toolguard/toolguard/gen_py/utils/pytest.py +++ /dev/null @@ -1,198 +0,0 @@ -from enum import Enum -import json -import os -from os.path import join -import subprocess -import sys -from typing import Any, List, Dict, Optional -from pydantic import BaseModel, Field -from contextlib import contextmanager - -from altk.pre_tool.toolguard.toolguard.data_types import FileTwin - - -class StrEnum(str, Enum): - """An abstract base class for string-based enums.""" - - pass - - -class TestOutcome(StrEnum): - passed = "passed" - failed = "failed" - - -class TracebackEntry(BaseModel): - path: str - lineno: int - message: str - - -class CrashInfo(BaseModel): - path: str - lineno: int - message: str - - -class CallInfo(BaseModel): - duration: float - outcome: TestOutcome - crash: Optional[CrashInfo] = None - traceback: Optional[List[TracebackEntry]] = None - longrepr: Optional[str] = None - - -class TestPhase(BaseModel): - duration: float - outcome: TestOutcome - - -class TestResult(BaseModel): - nodeid: str - lineno: int - outcome: TestOutcome - keywords: List[str] - setup: TestPhase - call: CallInfo - user_properties: Optional[List[Any]] = None - teardown: TestPhase - - -class ResultEntry(BaseModel): - nodeid: str - type: str - lineno: Optional[int] = None - - -class Collector(BaseModel): - nodeid: str - outcome: TestOutcome - result: List[ResultEntry] - longrepr: Optional[str] = None - - -class Summary(BaseModel): - failed: Optional[int] = 0 - total: int - collected: int - - -class TestReport(BaseModel): - created: float - duration: float - exitcode: int - root: str - environment: Dict[str, str] - summary: Summary - collectors: List[Collector] = Field(default=[]) - tests: List[TestResult] - - def all_tests_passed(self) -> 
bool:
-        return all([test.outcome == TestOutcome.passed for test in self.tests])
-
-    def all_tests_collected_successfully(self) -> bool:
-        return all([col.outcome == TestOutcome.passed for col in self.collectors])
-
-    def non_empty_tests(self) -> bool:
-        return self.summary.total > 0
-
-    def list_errors(self) -> List[str]:
-        errors = set()
-
-        # Python errors in the function under test
-        for col in self.collectors:
-            if col.outcome == TestOutcome.failed and col.longrepr:
-                errors.add(col.longrepr)
-
-        # applicative test failures
-        for test in self.tests:
-            if test.outcome == TestOutcome.failed:
-                error = test.call.crash.message
-                if test.user_properties:
-                    case_desc = test.user_properties[0].get("docstring")
-                    if case_desc:
-                        error = f"""Test case {case_desc} failed with the following message:\n {test.call.crash.message}"""
-                errors.add(error)
-        return list(errors)
-
-
-def run(folder: str, test_file: str, report_file: str) -> TestReport:
-    # _run_in_subprocess(folder, test_file, report_file)
-    _run_safe_python(folder, test_file, report_file)
-
-    report = read_test_report(os.path.join(folder, report_file))
-    # overwrite the report with an indented version
-    with open(os.path.join(folder, report_file), "w", encoding="utf-8") as f:
-        json.dump(report.model_dump(), f, indent=2)
-
-    return report
-
-
-@contextmanager
-def temp_sys_path(path):
-    """Temporarily insert a path into sys.path."""
-    sys.path.insert(0, path)
-    try:
-        yield
-    finally:
-        try:
-            sys.path.remove(path)
-        except ValueError:
-            pass
-
-
-def _run_safe_python(folder: str, test_file: str, report_file: str):
-    from smolagents.local_python_executor import LocalPythonExecutor
-
-    executor = LocalPythonExecutor(
-        additional_authorized_imports=["pytest"],
-        max_print_outputs_length=None,
-        additional_functions=None,
-    )
-    executor.static_tools = {"temp_sys_path": temp_sys_path}
-    code = f"""
-import pytest
-with temp_sys_path("{folder}"):
-    pytest.main(["{join(folder, test_file)}", "--quiet", "--json-report", "--json-report-file={join(folder, report_file)}"])
-"""
-    out = executor(code)
-    return out
-
-
-def _run_in_subprocess(folder: str, test_file: str, report_file: str):
-    subprocess.run(
-        [
-            "pytest",
-            test_file,
-            # "--verbose",
-            "--quiet",
-            "--json-report",
-            f"--json-report-file={report_file}",
-        ],
-        env={**os.environ, "PYTHONPATH": "."},
-        cwd=folder,
-    )
-
-
-def configure(folder: str):
-    """Adds the test function docstring to the output report."""
-
-    hook = """
-import pytest
-
-def pytest_runtest_protocol(item, nextitem):
-    docstring = item.function.__doc__
-    if docstring:
-        item.user_properties.append(("docstring", docstring))
-"""
-    FileTwin(file_name="conftest.py", content=hook).save(folder)
-
-
-def read_test_report(file_path: str) -> TestReport:
-    with open(file_path, "r") as file:
-        data = json.load(file)
-    return TestReport.model_validate(data, strict=False)
-
-
-# report = read_test_report("/Users/davidboaz/Documents/GitHub/gen_policy_validator/tau_airline/output/2025-03-12 08:54:16/pytest_report.json")
-# print(report.summary.failed)
diff --git a/altk/pre_tool/toolguard/toolguard/gen_py/utils/venv.py b/altk/pre_tool/toolguard/toolguard/gen_py/utils/venv.py
deleted file mode 100644
index 3220602..0000000
--- a/altk/pre_tool/toolguard/toolguard/gen_py/utils/venv.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import subprocess
-import sys
-import ensurepip
-
-
-def run(venv_dir: str, packages: list[str]):
-    # Bootstrap pip if not present
-    try:
-        import pip  # noqa: F401
-    except ImportError:
-        ensurepip.bootstrap(upgrade=True)
-
-    subprocess.run([sys.executable, "-m", "pip", "install"] + packages, check=True)
-
-    # # Create the virtual environment
-    # venv.create(venv_dir, with_pip=True)
-    #
-    # # install packages
-    # pip_executable = os.path.join(venv_dir, "bin", "pip")
-    # subprocess.run([pip_executable, "install"] + packages, check=True)
diff --git a/altk/pre_tool/toolguard/toolguard/llm/__init__.py b/altk/pre_tool/toolguard/toolguard/llm/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/altk/pre_tool/toolguard/toolguard/llm/i_tg_llm.py b/altk/pre_tool/toolguard/toolguard/llm/i_tg_llm.py
deleted file mode 100644
index df658e4..0000000
--- a/altk/pre_tool/toolguard/toolguard/llm/i_tg_llm.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import List, Dict
-
-
-class I_TG_LLM(ABC):
-    @abstractmethod
-    async def chat_json(self, messages: List[Dict]) -> Dict:
-        pass
-
-    @abstractmethod
-    async def generate(self, messages: List[Dict]) -> str:
-        pass
diff --git a/altk/pre_tool/toolguard/toolguard/llm/tg_llmevalkit.py b/altk/pre_tool/toolguard/toolguard/llm/tg_llmevalkit.py
deleted file mode 100644
index c983016..0000000
--- a/altk/pre_tool/toolguard/toolguard/llm/tg_llmevalkit.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from altk.pre_tool.toolguard.toolguard.llm.i_tg_llm import I_TG_LLM
-from altk.core.llm import ValidatingLLMClient
-
-
-class TG_LLMEval(I_TG_LLM):
-    def __init__(self, llm_client: ValidatingLLMClient):
-        if not isinstance(llm_client, ValidatingLLMClient):
-            print("llm_client is not a ValidatingLLMClient")
-            exit(1)
-        self.llm_client = llm_client
-
-    async def chat_json(self, messages: list[dict], schema=dict) -> dict:
-        return self.llm_client.generate(
-            prompt=messages, schema=schema, retries=5, schema_field=None
-        )
-
-    async def generate(self, messages: list[dict]) -> str:
-        return self.llm_client.generate(
-            prompt=messages, schema=str, retries=5, schema_field=None
-        )
diff --git a/altk/pre_tool/toolguard/toolguard/logging_utils.py b/altk/pre_tool/toolguard/toolguard/logging_utils.py
deleted file mode 100644
index 941edc4..0000000
--- a/altk/pre_tool/toolguard/toolguard/logging_utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import logging
-
-
-def add_log_file_handler(log_file: str):
-    formatter = logging.Formatter(
-        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-    )
-    file_handler = logging.FileHandler(log_file, encoding="utf-8")
-    file_handler.setFormatter(formatter)
-    logging.getLogger().addHandler(file_handler)
-
-
-def init_logging():
-    logging.getLogger().setLevel(logging.INFO)  # Default for other libraries
-    logging.getLogger("toolguard").setLevel(logging.DEBUG)  # debug for our library
-    logging.getLogger("mellea").setLevel(logging.DEBUG)
-    init_log_console_handler()
-
-
-def init_log_console_handler():
-    formatter = logging.Formatter(
-        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-    )
-
-    # Set up console handler
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(formatter)
-
-    # Set up the root logger
-    logging.basicConfig(level=logging.INFO, handlers=[console_handler])
diff --git a/altk/pre_tool/toolguard/toolguard/runtime.py b/altk/pre_tool/toolguard/toolguard/runtime.py
deleted file mode 100644
index 60b2c6b..0000000
--- a/altk/pre_tool/toolguard/toolguard/runtime.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import inspect
-import json
-import os
-from types import ModuleType
-from typing import Any, Dict, List, Optional, Type, Callable,
TypeVar -from pydantic import BaseModel -import importlib.util - -import functools -from .data_types import API_PARAM, RESULTS_FILENAME, FileTwin, RuntimeDomain, ToolPolicy - -from abc import ABC, abstractmethod - - -class IToolInvoker(ABC): - T = TypeVar("T") - - @abstractmethod - def invoke( - self, toolname: str, arguments: Dict[str, Any], return_type: Type[T] - ) -> T: ... - - -def load_toolguards( - directory: str, filename: str = RESULTS_FILENAME -) -> "ToolguardRuntime": - full_path = os.path.join(directory, filename) - with open(full_path, "r", encoding="utf-8") as f: - data = json.load(f) - result = ToolGuardsCodeGenerationResult(**data) - return ToolguardRuntime(result, directory) - - -class ToolGuardCodeResult(BaseModel): - tool: ToolPolicy - guard_fn_name: str - guard_file: FileTwin - item_guard_files: List[FileTwin | None] - test_files: List[FileTwin | None] - - -class ToolGuardsCodeGenerationResult(BaseModel): - root_dir: str - domain: RuntimeDomain - tools: Dict[str, ToolGuardCodeResult] - - def save( - self, directory: str, filename: str = RESULTS_FILENAME - ) -> "ToolGuardsCodeGenerationResult": - full_path = os.path.join(directory, filename) - with open(full_path, "w", encoding="utf-8") as f: - json.dump(self.model_dump(), f, indent=2) - return self - - -class ToolguardRuntime: - def __init__(self, result: ToolGuardsCodeGenerationResult, ctx_dir: str) -> None: - self._ctx_dir = ctx_dir - self._result = result - self._guards = {} - for tool_name, tool_result in result.tools.items(): - module = load_module_from_path(tool_result.guard_file.file_name, ctx_dir) - guard_fn = find_function_in_module(module, tool_result.guard_fn_name) - assert guard_fn, "Guard not found" - self._guards[tool_name] = guard_fn - - def _make_args( - self, guard_fn: Callable, args: dict, delegate: IToolInvoker - ) -> Dict[str, Any]: - sig = inspect.signature(guard_fn) - guard_args = {} - for p_name, param in sig.parameters.items(): - if p_name == API_PARAM: - module = load_module_from_path( - self._result.domain.app_api_impl.file_name, self._ctx_dir - ) - clazz = find_class_in_module( - module, self._result.domain.app_api_impl_class_name - ) - assert clazz, ( - f"class {self._result.domain.app_api_impl_class_name} not found in {self._result.domain.app_api_impl.file_name}" - ) - guard_args[p_name] = clazz(delegate) - else: - arg = args.get(p_name) - if inspect.isclass(param.annotation) and issubclass( - param.annotation, BaseModel - ): - guard_args[p_name] = param.annotation.model_construct(**arg) - else: - guard_args[p_name] = arg - return guard_args - - def check_toolcall(self, tool_name: str, args: dict, delegate: IToolInvoker): - guard_fn = self._guards.get(tool_name) - if guard_fn is None: # No guard assigned to this tool - return - guard_fn(**self._make_args(guard_fn, args, delegate)) - - -def file_to_module(file_path: str): - return file_path.removesuffix(".py").replace("/", ".") - - -def load_module_from_path(file_path: str, py_root: str) -> ModuleType: - full_path = os.path.abspath(os.path.join(py_root, file_path)) - if not os.path.exists(full_path): - raise ImportError(f"Module file does not exist: {full_path}") - - module_name = file_to_module(file_path) - - spec = importlib.util.spec_from_file_location(module_name, full_path) - if spec is None or spec.loader is None: - raise ImportError(f"Could not load module spec from {full_path}") - - module = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(module) # type: ignore - except Exception as e: - raise 
ImportError(f"Failed to execute module '{module_name}': {e}") from e
-
-    return module
-
-
-def find_function_in_module(module: ModuleType, function_name: str):
-    func = getattr(module, function_name, None)
-    if func is None or not inspect.isfunction(func):
-        raise AttributeError(
-            f"Function '{function_name}' not found in module '{module.__name__}'"
-        )
-    return func
-
-
-def find_class_in_module(module: ModuleType, class_name: str) -> Optional[Type]:
-    cls = getattr(module, class_name, None)
-    if isinstance(cls, type):
-        return cls
-    return None
-
-
-T = TypeVar("T")
-
-
-def guard_methods(obj: T, guards_folder: str) -> T:
-    """Wraps all public bound methods of the given instance with tool-guard checks."""
-    for attr_name in dir(obj):
-        if attr_name.startswith("_"):
-            continue
-        attr = getattr(obj, attr_name)
-        if callable(attr):
-            wrapped = guard_before_call(guards_folder)(attr)
-            setattr(obj, attr_name, wrapped)
-    return obj
-
-
-class ToolMethodsInvoker(IToolInvoker):
-    def __init__(self, object: object) -> None:
-        self._obj = object
-
-    def invoke(self, toolname: str, arguments: Dict[str, Any], model: Type[T]) -> T:
-        mtd = getattr(self._obj, toolname)
-        assert callable(mtd), f"Tool {toolname} was not found"
-        return mtd(**arguments)
-
-
-class ToolFunctionsInvoker(IToolInvoker):
-    def __init__(self, funcs: List[Callable]) -> None:
-        self._funcs_by_name = {func.__name__: func for func in funcs}
-
-    def invoke(self, toolname: str, arguments: Dict[str, Any], model: Type[T]) -> T:
-        func = self._funcs_by_name.get(toolname)
-        assert callable(func), f"Tool {toolname} was not found"
-        return func(**arguments)
-
-
-def guard_before_call(guards_folder: str) -> Callable[[Callable], Callable]:
-    """Decorator factory that checks the matching tool guard before each call."""
-    toolguards = load_toolguards(guards_folder)
-
-    def decorator(func: Callable) -> Callable:
-        @functools.wraps(func)
-        def wrapper(*args: Any, **kwargs: Any) -> Any:
-            toolguards.check_toolcall(
-                func.__name__, kwargs, ToolMethodsInvoker(func.__self__)
-            )
-            return func(*args, **kwargs)
-
-        return wrapper  # type: ignore
-
-    return decorator
diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/__init__.py b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/create_oas_summary.py b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/create_oas_summary.py
deleted file mode 100644
index 68bd025..0000000
--- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/create_oas_summary.py
+++ /dev/null
@@ -1,254 +0,0 @@
-import json
-from typing import Any, Dict, List
-
-from altk.pre_tool.toolguard.toolguard.tool_policy_extractor.text_tool_policy_generator import (
-    ToolInfo,
-)
-
-
-class OASSummarizer:
-    def __init__(self, oas: Dict[str, Any]):
-        self.oas = oas
-        self.components = oas.get("components", {}).get("schemas", {})
-
-    def summarize(self) -> List[ToolInfo]:
-        operations = []
-        for path, methods in self.oas.get("paths", {}).items():
-            for method, operation in methods.items():
-                summary = self._format_operation(path, method, operation)
-                tool_info = ToolInfo(
-                    name=summary.get("name"),
-                    description=summary.get("description"),
-                    signature=summary.get("signature"),
-                    parameters=summary.get("params"),
-                    full_description=json.dumps(summary, indent=4),
-                )
-                operations.append(tool_info)
-        return operations
-
-    def _format_operation(
-        self,
path: str, method: str, operation: Dict[str, Any] - ) -> Dict[str, Any]: - operation_id = operation.get( - "operationId", f"{method}_{path.strip('/').replace('/', '_')}" - ) - class_name = operation_id - description = operation.get("description", "") - request_body = operation.get("requestBody", {}) - params = self._parse_request_body(request_body) if request_body else {} - signature = self._generate_signature( - class_name, params, operation.get("responses", {}) - ) - example = operation.get( - "x-input-examples", [] - ) # self._generate_example(class_name, params) - output_examples = operation.get( - "x-output-examples", [] - ) # self._parse_response_examples(operation.get("responses", {})) - - return { - "name": class_name, - "signature": signature, - "description": description, - "params": params, - "examples": [example], - "output_examples": output_examples, - } - - def _parse_request_body(self, request_body: Dict[str, Any]) -> Dict[str, Any]: - content = request_body.get("content", {}).get("application/json", {}) - schema = self._resolve_ref(content.get("schema", {})) - props = schema.get("properties", {}) - required = schema.get("required", []) - params = {} - for param_name, param_schema in props.items(): - resolved_schema = self._resolve_ref(param_schema) - param_type = self._resolve_schema_type(resolved_schema) - param_desc = resolved_schema.get("description", "") - params[param_name] = { - "type": param_type, - "description": param_desc, - "required": param_name in required, - } - return params - - # def _resolve_schema_type(self, schema: Dict[str, Any]) -> str: - # if "anyOf" in schema: - # return "Union[" + ", ".join(self._resolve_schema_type(s) for s in schema["anyOf"]) + "]" - # if "oneOf" in schema: - # return "Union[" + ", ".join(self._resolve_schema_type(s) for s in schema["oneOf"]) + "]" - # if "$ref" in schema: - # return self._resolve_schema_type(self._resolve_ref(schema)) - # if schema.get("type") == "array": - # item_type = self._resolve_schema_type(schema.get("items", {})) - # return f"List[{item_type}]" - # if schema.get("type") == "object": - # return "Dict[str, Any]" - # - # return { - # "string": "str", - # "integer": "int", - # "number": "float", - # "boolean": "bool", - # "object": "Dict[str, Any]", - # }.get(schema.get("type", "Any"), "Any") - - def _resolve_schema_type(self, schema: Dict[str, Any]) -> str: - if "anyOf" in schema: - return ( - "Union[" - + ", ".join(self._resolve_schema_type(s) for s in schema["anyOf"]) - + "]" - ) - if "oneOf" in schema: - return ( - "Union[" - + ", ".join(self._resolve_schema_type(s) for s in schema["oneOf"]) - + "]" - ) - if "$ref" in schema: - return self._resolve_schema_type(self._resolve_ref(schema)) - if schema.get("type") == "array": - item_type = self._resolve_schema_type(schema.get("items", {})) - return f"List[{item_type}]" - if schema.get("type") == "object": - return "Dict[str, Any]" - - type_value = schema.get("type", "Any") - if isinstance(type_value, list): - # Filter out "null" and resolve remaining types - non_null_types = [t for t in type_value if t != "null"] - if not non_null_types: - return "Optional[Any]" - if len(non_null_types) == 1: - base_type = self._resolve_schema_type( - {**schema, "type": non_null_types[0]} - ) - else: - base_type = ( - "Union[" - + ", ".join( - self._resolve_schema_type({"type": t}) for t in non_null_types - ) - + "]" - ) - return f"Optional[{base_type}]" if "null" in type_value else base_type - - return { - "string": "str", - "integer": "int", - "number": "float", - "boolean": 
"bool", - "object": "Dict[str, Any]", - }.get(type_value, "Any") - - def _resolve_ref(self, schema: Dict[str, Any]) -> Dict[str, Any]: - if isinstance(schema, Dict): - if "$ref" in schema: - ref_path = schema["$ref"] - if ref_path.startswith("#/components/schemas/"): - key = ref_path.split("/")[-1] - return self.components.get(key, {}) - return schema - return schema - - def _generate_signature( - self, class_name: str, params: Dict[str, Any], responses: Dict[str, Any] - ) -> str: - args = ", ".join(f"{name}: {meta['type']}" for name, meta in params.items()) - out = "str" - if responses and "200" in responses: - content = responses["200"]["content"] - app_json = content.get("application/json", {}) - schema = self._resolve_ref(app_json.get("schema", {})) - out = self._resolve_schema_type(schema) - return f"{class_name}({args}) -> {out}" - - def _generate_example(self, class_name: str, params: Dict[str, Any]) -> str: - args = ", ".join( - '"example_string"' if meta["type"].startswith("str") else "0" - for _, meta in params.items() - ) - return f"{class_name}({args})" - - def _parse_response_examples(self, responses: Dict[str, Any]) -> List[str]: - examples = [] - for response in responses.values(): - content = response.get("content", {}) - app_json = content.get("application/json", {}) - schema = self._resolve_ref(app_json.get("schema", {})) - - if "example" in app_json: - example_data = app_json["example"] - elif "examples" in app_json: - example_data = next(iter(app_json["examples"].values()), {}).get( - "value" - ) - else: - example_data = self._construct_example_from_schema(schema) - - if example_data is not None: - try: - examples.append(json.dumps(example_data)) - except Exception: - examples.append(str(example_data)) - return examples - - def _construct_example_from_schema(self, schema: Dict[str, Any]) -> Any: - schema = self._resolve_ref(schema) - if not isinstance(schema, Dict): - return schema - schema_type = schema.get("type") - - if schema_type == "object": - if "additionalProperties" in schema: - value_schema = self._resolve_ref(schema["additionalProperties"]) - return { - "example_key": self._construct_example_from_schema(value_schema) - } - props = schema.get("properties", {}) - return { - key: self._construct_example_from_schema(self._resolve_ref(subschema)) - for key, subschema in props.items() - } - - if schema_type == "array": - item_schema = self._resolve_ref(schema.get("items", {})) - return [self._construct_example_from_schema(item_schema)] - - if schema_type == "string": - return schema.get("example", "example_string") - if schema_type == "integer": - return schema.get("example", 42) - if schema_type == "number": - return schema.get("example", 3.14) - if schema_type == "boolean": - return schema.get("example", True) - - if "anyOf" in schema: - return self._construct_example_from_schema(schema["anyOf"][0]) - if "oneOf" in schema: - return self._construct_example_from_schema(schema["oneOf"][0]) - if "allOf" in schema: - return self._construct_example_from_schema(schema["allOf"][0]) - - return "example_value" - - -if __name__ == "__main__": - oas_file = ( - "/Users/naamazwerdling/Documents/OASB/policy_validation/airline/oas2.json" - ) - # oas_file = "/Users/naamazwerdling/Documents/OASB/policy_validation/orca/bank/oas.json" - shortfile = "/Users/naamazwerdling/Documents/OASB/policy_validation/airline/s1.json" - # shortfile = "/Users/naamazwerdling/Documents/OASB/policy_validation/orca/bank/short.json" - with open(oas_file) as f: - oas_data = json.load(f) - - summarizer = 
OASSummarizer(oas_data)
-    summary = summarizer.summarize()
-    with open(shortfile, "w") as outfile:
-        json.dump(summary, outfile, indent=4)
-
-    print(json.dumps(summary, indent=2))
diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_examples.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_examples.txt
deleted file mode 100644
index 7ed7c92..0000000
--- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_examples.txt
+++ /dev/null
@@ -1,49 +0,0 @@
-Task:
-For the given policy, check if any examples (violating or compliance) are missing. If so, generate additional examples and return only the newly created examples, without modifying or repeating the existing ones.
-
-Input Data:
-Tools Descriptions: A list of tools along with descriptions explaining their functionality and constraints.
-Target Tool (ToolX): The specific tool for which relevant policies need to be identified.
-Policy: The policy to write additional examples for, together with the examples extracted in previous stages.
-
-Steps:
-1. Identify Missing Examples:
-Locate the examples under "examples".
-If the "examples" array is missing or empty, generate new examples in natural text, both violating and compliance examples.
-If only one type of example is missing (violating or compliance), generate only the missing type.
-Do not repeat or modify existing examples—only add new ones.
-
-2. Generate Violating Examples (if missing):
-Provide new violating examples that highlight different ways the policy could be breached.
-Cover common mistakes, edge cases, and boundary conditions.
-Clearly state why each example violates the policy.
-
-3. Generate Compliance Examples (if missing):
-Provide new compliance examples demonstrating correct adherence to the policy.
-Ensure compliance cases clearly illustrate why they meet the policy’s requirements.
-
-4. Ensure Full Coverage:
-If the policy has multiple conditions, generate examples covering each condition separately.
-Consider numeric limits, timing constraints, and optional parameters where relevant.
-
-5. Format the Output in JSON:
-Only include newly generated examples. Make sure examples are written in natural text.
-Maintain the correct structure:
-
-Output Format (JSON):
-{
-    "violating_examples": [
-        "Violating example 1",
-        "Violating example 2",
-        "Violating example 3"
-    ],
-    "compliance_examples": [
-        "Compliance example 1",
-        "Compliance example 2",
-        "Compliance example 3"
-    ]
-}
-
-If no new examples are needed, return an empty object:
-
-{}
diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_policies.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_policies.txt
deleted file mode 100644
index 917abd4..0000000
--- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_policies.txt
+++ /dev/null
@@ -1,57 +0,0 @@
-Task:
-Given the extracted TPTD (Tool Policy Text Description) for the given target tool (ToolX), verify completeness and identify any additional policies that were missed. Only return newly identified policies.
-
-Instructions:
-
-1. Identify Missing Policies:
-    * Reanalyze the Policy Document to check if any relevant policies for ToolX were overlooked in the previous stage.
-
-2. Extract and format any missing policies following these criteria:
-    * Must be specific to ToolX.
-    * Must be actionable and enforceable based on ToolX’s parameters, chat history, and data access.
-    * Must be self-contained with all necessary details.
- * Must have exact verbatim references from the Policy Document (not inferred from tool descriptions). - -3. Validate Policy References: - * Verify that each reference is a verbatim excerpt from the Policy Document and not inferred from tool descriptions. - * If a policy is supported by multiple passages, list them separately in the "references" array. - -4. Output Only New Policies: - * If a policy was already extracted in Stage 1, do not include it in the output. - * If no additional policies are found, return an empty "policies" array. - -Input Format: -Policy Document – A text containing policies, rules, or constraints governing tool usage. -Tools Descriptions – A list of tools with descriptions explaining their functionality and constraints. -Target Tool (ToolX) – The specific tool for which relevant policies need to be identified. -TPTD (Tool Policy Text Description) – A JSON object containing extracted policies from previous stages. - -Output Format (JSON): -{ - "policies": [ - { - "policy_name": "", - "description": "", - "references": [ - "", - "", - ... - ] - }, - ... - { - "policy_name": "", - "description": "", - "references": [ - "" - ... - ] - } - ] -} - - -If no additional relevant policies exist, return: -{ - "policies": [] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_references.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_references.txt deleted file mode 100644 index dc98cd4..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/add_references.txt +++ /dev/null @@ -1,38 +0,0 @@ -Task: -Given a Policy Document, a Target Tool (ToolX), and a specific policy, extract all verbatim supporting references from the Policy Document that directly support the given policy. - -Instructions: -1. Identify Supporting References: - * Search the Policy Document for text that explicitly or implicitly supports the given policy description. - * Extract only the relevant passages that provide direct support for enforcing this policy. - - -2. Ensure Contiguous Verbatim References: - * Each reference must be an exact excerpt from the Policy Document, without alterations. - * If the supportive text appears in multiple non-contiguous segments, extract them separately as distinct items in the list. - * If the relevant text is interrupted by unrelated content, split it into multiple separate references. - * Do not infer meaning beyond the verbatim text. - -3. Output Format: - * If multiple non-contiguous passages support the policy, return them as separate strings in a list. - * If no supporting references are found, return an empty list. - -Input Format: -Policy Document: A text containing policies, rules, or constraints governing tool usage. -Tools Descriptions: A list of tools with descriptions explaining their functionality and constraints. -Target Tool (ToolX): The specific tool for which relevant policies need to be identified. -Policy: The policy for which supporting references need to be extracted. 
- -Output Format (JSON): -{ - "references": [ - "", - "", - "" - ] -} - -If no relevant references are found, return: -{ - "references": [] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/create_examples.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/create_examples.txt deleted file mode 100644 index 4f13f90..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/create_examples.txt +++ /dev/null @@ -1,64 +0,0 @@ -Task Overview: -Your task is to generate, in natural text, examples for the given policy. This policy defines constraints and rules that must be validated before calling a specific tool, referred to as "ToolX." - -Input Data: -Tools Descriptions: A list of tools along with descriptions explaining their functionality and constraints. -Target Tool (ToolX): The specific tool for which relevant policies need to be identified. -Policy: The policy for which to write the examples - -Objective: -For the given policy, generate examples that illustrate both violations (where the policy is not followed) and compliance (where the policy is correctly followed). The goal is to create specific and actionable examples that can be validated. -Ensure you provide a detailed textual description outlining the use case and the conditions that would either violate or comply with the policy in the given example. - -Guidelines for Creating Examples: - -1. Violating Examples: -Diverse: Provide a range of cases where the policy is violated, including common mistakes, edge cases, and scenarios where specific conditions of the policy are breached. -Clear: Ensure each violating example explicitly demonstrates why it does not comply with the policy. Highlight the incorrect aspects. -Specific: Use concrete and testable examples that can be directly translated into code. - -2. Compliance Examples: -Diverse: Include a variety of correct cases, covering different valid ways of adhering to the policy. -Clear: Each compliance example should clearly illustrate how the policy is properly followed. -Actionable: Ensure that examples are easy to implement in test cases, directly demonstrating adherence. - -3. Edge Cases: -Include edge cases for both compliance and violation. Pay special attention to boundary conditions such as numeric limits, optional parameters, or timing constraints. - -4. Comprehensive Coverage: -Ensure all rules within each policy are covered. If a policy contains multiple conditions, provide examples for each possible combination of compliance and violation. 
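Since example_creator in the (now removed) text_tool_policy_generator.py copies the two top-level arrays straight out of the model's reply, a caller may want to check the reply against the JSON contract shown below before using it. A minimal, hypothetical validation sketch:

```python
def validate_examples_payload(payload: dict) -> bool:
    # Expected contract: two top-level keys, each a list of
    # natural-text strings describing example scenarios.
    for key in ("violating_examples", "compliance_examples"):
        values = payload.get(key, [])
        if not isinstance(values, list) or not all(isinstance(v, str) for v in values):
            return False
    return True
```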
- -Output Format (JSON): -{ - "violating_examples": [ - "Violating example 1", - "Violating example 2", - "Violating example 3" - ], - "compliance_examples": [ - "Compliance example 1", - "Compliance example 2", - "Compliance example 3" - ] -} - -Example: - -Target Tool: rent_car -Policy: The user can rent only medium cars in red or small cars in blue or any yellow car - -Output: -{ - "violating_examples": [ - "A user asks to rent a green car", - "The user requests to rent a medium blue car", - "Renting a small red car", - "Renting a big red car" - ], - "compliance_examples": [ - "A user asks to rent a red medium car", - "The user requests to rent a small blue car", - "Renting a big yellow car", - "Renting a small yellow car" - ] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/create_policy.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/create_policy.txt deleted file mode 100644 index 5a0b061..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/create_policy.txt +++ /dev/null @@ -1,57 +0,0 @@ -Task: -Given a Policy Document, Tools Descriptions, and a Target Tool (ToolX), extract and format all relevant policies that should be verified before invoking ToolX. - -Instructions: -1. Identify policies from the Policy Document that are relevant to ToolX. -2. Ensure that extracted policies: - * Are specifically applicable to ToolX (not general for all tools). - * Can be validated before calling ToolX. - * Are actionable, meaning they can be enforced based on ToolX’s parameters, chat history, and data access. - * Are self-contained, meaning the description should contain all necessary details without requiring access to external references. -3. Split policies into the smallest reasonable parts so that: - * Each extracted policy contains only one condition. - * If a policy includes multiple conditions, split it into separate policies whenever possible. -4. Provide exact references to the Policy Document for each extracted policy. - * Locate the exact passage(s) in the Policy Document that fully support the policy statement. - * Each reference must be a contiguous, verbatim excerpt from the Policy Document, ensuring it can be precisely highlighted. - * If no single passage fully supports the policy, provide multiple distinct references instead. - * Each reference must appear exactly as written in the Policy Document. - * Each supporting passage should be listed separately within the "references" array. - * Make sure to extract only from the Policy Document and not from other input information like the tool descriptions. -5. If no policies are applicable, return: -"There are no relevant policies to validate before calling ToolX." - -Input Format: -Policy Document – A text containing policies, rules, or constraints governing tool usage. -Tools Descriptions – A list of tools with descriptions explaining their functionality and constraints. -Target Tool (ToolX) – The specific tool for which relevant policies need to be identified. - -Output Format (JSON): -{ - "policies": [ - { - "policy_name": "", - "description": "", - "references": [ - "", - "", - ... - ], - "iteration_added": 0 - }, - ... - { - "policy_name": "", - "description": "", - "references": [ - "" - ... 
- ], - "iteration_added": 0 - } - ] -} -If no relevant policies exist, return: -{ - "policies": [] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/examples_reviewer.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/examples_reviewer.txt deleted file mode 100644 index 69957b7..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/examples_reviewer.txt +++ /dev/null @@ -1,41 +0,0 @@ -Objective: -You are a reviewer. Your job is to check if a given example correctly shows a good __EXAMPLE_TYPE__ example of a specific rule (policy) for using a tool called ToolX. - -Input: -Tools Descriptions: A list of tools and their functions and constraints. -The Target Tool (ToolX): The tool for which the policy applies. -The Policy Name and Description: Specifies the policy under evaluation. -A __EXAMPLE_TYPE__ Example: A natural language use-case for __EXAMPLE_TYPE__ the policy. - -What You Need to Do -For the __EXAMPLE_TYPE__ example, answer these four questions: - * Is this example indeed a valid __EXAMPLE_TYPE__ example of the policy, or the opposite? (compliance or violating example) - * Is the example clear and detailed enough to understand? - * Can someone write a test case based on this example? - * Can a validation function catch this problem before ToolX runs? - -For each question: - * First, explain why you think the example passes or fails. - * Then return a true or false value. - -Format the Output in JSON: -``` -{ - "is_valid": { - "justification": "Explain whether this example actually is a __EXAMPLE_TYPE__ example, and why.", - "value": true or false - }, - "clarity_and_detail": { - "justification": "Explain if the example is clear, specific, and detailed enough to understand the problem.", - "value": true or false - }, - "actionability": { - "justification": "Explain if someone could write a test case based on this example.", - "value": true or false - }, - "verifiability": { - "justification": "Explain if a validation function could detect this problem before ToolX is used.", - "value": true or false - } -} -``` diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/fix_example.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/fix_example.txt deleted file mode 100644 index ea910e1..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/fix_example.txt +++ /dev/null @@ -1,30 +0,0 @@ -Objective: -Your task is to revise and improve a single __EXAMPLE_TYPE__ example for a given policy. The goal is to ensure that the example is clear, precise, actionable, and suitable for automated validation or unit testing. - -Task: Fix and Improve an Example -Given: -* Policy Document: A text outlining rules, constraints, or usage policies for various tools. -* Tools Descriptions: A list of tools along with explanations of their capabilities and constraints. -* Target Tool (ToolX): The specific tool for which the policy is being analyzed. -* Policy Name: The name of the policy -* Policy Description: The description of the policy -* Example: The __EXAMPLE_TYPE__ example to fix - - -Update the example as follows: - -1. Clarity and Specificity - * Clearly show how and why the example violates or complies with the policy. - * Replace vague or ambiguous descriptions with precise actions, conditions, and parameters. - * Ensure the language is unambiguous and easy to understand. - -2. 
Actionability for Testing - * The revised example must be concrete and testable. - * Include enough detail so that a developer can implement a unit test or validation rule from it without referring back to the full policy document. - * Examples should include relevant inputs, conditions, and outcomes when applicable. - - -Output Format (JSON): -{ - "revised_example": "The updated example text, or the original text if no changes were necessary." -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/functions.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/functions.txt deleted file mode 100644 index 3b65b39..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/functions.txt +++ /dev/null @@ -1,56 +0,0 @@ -Given an OpenAPI Specification (OAS) file and a selected operation (such as GET /items/{item_id}), follow these steps to create a Python function and sample input/output examples. - -Step 1: Identify Operation Components -From the operation, extract: -* Operation ID (or generate a function name using HTTP method + path) -* HTTP method and path -* Parameters: - * Path, query, header, and cookie parameters -* Request body (if defined) -* Response schema (focus on HTTP 200 or default response) - -Step 2: Write the Python Function Signature -Create a Python function named after the operation. Follow these rules: -* Use snake_case for function names. -* Each parameter should become a function argument: - * Required parameters have no default. - * Optional parameters have default values (e.g., None, False, ""). -* Use basic type hints: str, int, float, bool, dict, List[str], etc. - - -Step 3: Generate 3 Diverse Parameter Examples -Write three different function calls that showcase a variety of: -* Simple vs. complex values -* Optional vs. required parameters -* Empty, large, or edge-case values -* Realistic business logic (e.g., filters, limits) - -Example variations: -* Empty query strings -* Lists or dicts as input -* Special characters -* Boundary numbers - -Step 4: Generate 3 Diverse Return Value Examples -Create three realistic return values in Python dict format, based on the operation's response schema. The examples should vary in: -* Field presence (e.g., missing optional fields) -* Field types (e.g., one returns a list, one returns a dict) -* Structure (nested vs. flat) -* Value range (small, large, edge cases) - -Step 5: Return in JSON Format -Return the following in a single JSON object: - -{ - "function_signature": "Python function definition as a string", - "input_examples": [ - "Function call as string #1", - "Function call as string #2", - "Function call as string #3" - ], - "output_examples": [ - { "example_response_1": "..." }, - { "example_response_2": "..." }, - { "example_response_3": "..." } - ] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge.txt deleted file mode 100644 index 8756b7b..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge.txt +++ /dev/null @@ -1,51 +0,0 @@ -Task: -Given a Policy Document, Tools Descriptions, a Target Tool (ToolX), and a TPTD JSON object containing extracted policies, refine the policies by: - -1. Merging Identical Policies: -If two or more policies have identical descriptions and references, consolidate them into a single policy. -Preserve all references from the merged policies. - -2. 
Merging Policies with Logical OR Relationships: -If multiple policies describe conditions where at least one must be satisfied (i.e., connected by an OR relationship), merge them into a single policy. -Ensure the combined policy reflects the OR logic and retains all original references. - -3. Ensuring Clarity and Enforceability: -Each extracted policy should contain only a single actionable condition. -Maintain clear references for each policy to ensure traceability. -Retain the iteration_added field and increment it by 1 for any newly created policies. - -Input Format: -Policy Document – A text containing policies, rules, or constraints governing tool usage. -Tools Descriptions – A list of tools with descriptions explaining their functionality and constraints. -Target Tool (ToolX) – The specific tool for which relevant policies need to be identified. -TPTD (Tool Policy Text Description) – A JSON object containing extracted policies from previous stages. - -Output Format (JSON): -{ - "policies": [ - { - "policy_name": "", - "description": "", - "references": [ - "", - "", - ... - ] - }, - ... - { - "policy_name": "", - "description": "", - "references": [ - "" - ... - ] - } - ] -} - - -If no additional relevant policies exist, return: -{ - "policies": [] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge_and_split.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge_and_split.txt deleted file mode 100644 index 0ceb541..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge_and_split.txt +++ /dev/null @@ -1,52 +0,0 @@ -Task: -Given a Policy Document, Tools Descriptions, a Target Tool (ToolX), and a TPTD JSON object containing extracted policies, refine the policies by: - -1. Merging Identical Policies: -If two or more policies have identical descriptions and references, consolidate them into a single policy. -Preserve all references from the merged policies. - -2. Breaking Down Multi-Condition Policies: -If a policy contains multiple conditions, separate them into distinct policies whenever feasible. -If a policy requires both Condition A and Condition B to be met, create two separate policies. -If a policy allows either Condition A or Condition B to apply, it should remain as a single policy. - -3. Ensuring Clarity and Enforceability: -Each extracted policy should contain only a single actionable condition. -Maintain clear references for each policy to ensure traceability. -Retain the iteration_added field and increment it by 1 for any newly created policies. - -Input Format: -Policy Document – A text containing policies, rules, or constraints governing tool usage. -Tools Descriptions – A list of tools with descriptions explaining their functionality and constraints. -Target Tool (ToolX) – The specific tool for which relevant policies need to be identified. -TPTD (Tool Policy Text Description) – A JSON object containing extracted policies from previous stages. - -Output Format (JSON): -{ - "policies": [ - { - "policy_name": "", - "description": "", - "references": [ - "", - "", - ... - ] - }, - ... - { - "policy_name": "", - "description": "", - "references": [ - "" - ... 
- ] - } - ] -} - - -If no additional relevant policies exist, return: -{ - "policies": [] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge_examples.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge_examples.txt deleted file mode 100644 index 9f7dc10..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/merge_examples.txt +++ /dev/null @@ -1,58 +0,0 @@ -Task: Deduplicate Examples in Policy Evaluation Context - -You are given the following inputs: -* Policy Document: A text outlining rules, constraints, or usage policies for various tools. -* Tools Descriptions: A list of tools along with explanations of their capabilities and constraints. -* Target Tool (ToolX): The specific tool for which the policy is being analyzed. -* Policy Name: The name of the policy -* Policy Description: The description of the policy -* Violating Examples: A list of example scenarios that violate the policy. -* Compliance Examples: A list of example scenarios that are compliant with the policy. - -Your task is to refine both sets of examples (violating and compliance) by merging identical or semantically duplicate examples: -Instructions: -1. Review the examples in each set independently (i.e., review violating examples separately from compliance examples). -2. Identify duplicates: - * If two or more examples describe the same scenario (identical or paraphrased), merge them into one clear and representative example. - * Examples are considered duplicates if they involve the same conditions and policy points being violated or followed. -3. Do not merge examples that: - * Refer to different scenarios, distinct edge cases, or variations of policy application. - * Provide unique value in understanding how the policy is applied or violated. - -Output JSON Format: -{ - "violating_examples": [ ... ], - "compliance_examples": [ ... ] -} - -Example: -Input (Before Merging): -{ - "violating_examples": [ - "ToolX is used without an authentication token.", - "A request is sent to ToolX without the required authentication token.", - "Input exceeds the character limit for ToolX.", - "ToolX request has input of more than 500 characters.", - "A restricted keyword is used in ToolX query." - ], - "compliance_examples": [ - "A valid authentication token is used when calling ToolX.", - "The input for ToolX is within the allowed character limit.", - "No restricted keywords are present in the ToolX request.", - "ToolX is accessed with proper credentials and policy-compliant query." - ] -} - -Output (After Merging): -{ - "violating_examples": [ - "A request is sent to ToolX without the required authentication token, violating the policy that mandates authentication.", - "The input parameters for ToolX exceed the allowed 500-character limit, which violates input size constraints.", - "A restricted keyword is used in the ToolX request, which the policy explicitly disallows." - ], - "compliance_examples": [ - "A user provides a valid authentication token when calling ToolX, following the authentication requirement.", - "The input parameters for ToolX are within the allowed character limit, meeting the policy criteria.", - "ToolX is accessed with valid credentials and a policy-compliant query containing no restricted keywords." 
- ] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/policy_reviewer.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/policy_reviewer.txt deleted file mode 100644 index 2f2daa4..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/policy_reviewer.txt +++ /dev/null @@ -1,39 +0,0 @@ -Task: -Evaluate and refine the given policy to ensure it meets standards for quality, relevance, and applicability. Your objectives are: -Determine if the policy is relevant and actionable for ToolX. -Verify whether the policy can be validated before invoking ToolX whenever possible. -Identify gaps in the policy and suggest improvements to make it self-contained if necessary. - -Input: -Policy Document: A text outlining policies, rules, or constraints related to tool usage. -Tool Descriptions: A list detailing tools, their functionalities, and constraints. -Target Tool (ToolX): The specific tool for which relevant policies must be identified. -Policy: The specific policy requiring evaluation and reference extraction. - -Evaluation Criteria: -is_relevant: Does the policy specifically apply to ToolX? -is_tool_specific: Is the policy tailored to ToolX, or is it too broad (e.g., applicable to all tools)? -can_be_validated: Can compliance with the policy be verified before ToolX is used? -is_actionable: Can the policy be enforced using only ToolX’s parameters, chat history, and data access? -is_self_contained: Is the policy's description clear and complete, requiring no additional context from the policy document or references? If not, suggest an improved version. - -Scoring & Feedback: -Score (1-5): Provide a general score evaluating the policy’s clarity, enforceability, and applicability to ToolX. -Comments: Justify the assigned score and explain any deficiencies. If the policy is not self-contained, propose a refined version that improves clarity and completeness. - -Output JSON Format: -{ - "policy_name": "", - "description": "", - "references": [ - "" - ], - "is_relevant": , - "is_tool_specific": , - "can_be_validated": , - "is_actionable": , - "is_self_contained": , - "alternative_description": , - "comments": "", - "score": -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/split.txt b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/split.txt deleted file mode 100644 index 199ee8c..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/prompts/split.txt +++ /dev/null @@ -1,48 +0,0 @@ -Task: -Given a Policy Document, Tools Descriptions, a Target Tool (ToolX), and a TPTD JSON object containing extracted policies, refine the policies by: - -1. Breaking Down Multi-Condition Policies: -If a policy contains multiple conditions, separate them into distinct policies whenever feasible. -If a policy requires both Condition A and Condition B to be met, create two separate policies. -If a policy allows either Condition A or Condition B to apply, it should remain as a single policy. - -2. Ensuring Clarity and Enforceability: -Each extracted policy should contain only a single actionable condition. -Maintain clear references for each policy to ensure traceability. -Retain the iteration_added field and increment it by 1 for any newly created policies. - -Input Format: -Policy Document – A text containing policies, rules, or constraints governing tool usage. -Tools Descriptions – A list of tools with descriptions explaining their functionality and constraints. 
-Target Tool (ToolX) – The specific tool for which relevant policies need to be identified. -TPTD (Tool Policy Text Description) – A JSON object containing extracted policies from previous stages. - -Output Format (JSON): -{ - "policies": [ - { - "policy_name": "", - "description": "", - "references": [ - "", - "", - ... - ] - }, - ... - { - "policy_name": "", - "description": "", - "references": [ - "" - ... - ] - } - ] -} - - -If no additional relevant policies exist, return: -{ - "policies": [] -} diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/text_tool_policy_generator.py b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/text_tool_policy_generator.py deleted file mode 100644 index de28c0f..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/text_tool_policy_generator.py +++ /dev/null @@ -1,467 +0,0 @@ -import asyncio -import json -import os -import inspect -import sys -from typing import Any, Callable, List, Optional - -from langchain_core.tools import BaseTool -from pydantic import BaseModel - -from altk.pre_tool.toolguard.toolguard.data_types import load_tool_policy, ToolPolicy -from altk.pre_tool.toolguard.toolguard.llm.i_tg_llm import I_TG_LLM -from altk.pre_tool.toolguard.toolguard.tool_policy_extractor.utils import ( - read_prompt_file, - generate_messages, - save_output, - find_mismatched_references, -) - - -class ToolInfo(BaseModel): - name: str - description: str - parameters: Any - signature: str - full_description: str - - @classmethod - def from_function(cls, fn: Callable) -> "ToolInfo": - # Assumes @tool decorator from langchain https://python.langchain.com/docs/how_to/custom_tools/ - # or a plain function with doc string - def doc_summary(doc: str): - paragraphs = [p.strip() for p in doc.split("\n\n") if p.strip()] - return paragraphs[0] if paragraphs else "" - - fn_name = fn.name if hasattr(fn, "name") else fn.__name__ - sig = fn_name + str(get_tool_signature(fn)) - full_desc = ( - fn.description - if hasattr(fn, "description") - else fn.__doc__.strip() - if fn.__doc__ - else (inspect.getdoc(fn) or "") - ) - return cls( - name=fn_name, - description=doc_summary(full_desc), - full_description=full_desc, - parameters=fn.args_schema.model_json_schema() - if hasattr(fn, "args_schema") - else inspect.getdoc(fn), - signature=sig, - ) - - -def get_tool_signature(obj): - if inspect.isfunction(obj): - return inspect.signature(obj) - if hasattr(obj, "func") and inspect.isfunction(obj.func): - return inspect.signature(obj.func) - if hasattr(obj, "args_schema"): - schema = obj.args_schema - fields = schema.model_fields - params = ", ".join( - f"{name}: {field.annotation.__name__ if hasattr(field.annotation, '__name__') else field.annotation}" - for name, field in fields.items() - ) - return f"({params})" - return None - - -def extract_functions(file_path: str) -> List[Callable]: - import importlib.util - import inspect - - module_name = os.path.splitext(os.path.basename(file_path))[0] - - # Add project root to sys.path - project_root = os.path.abspath( - os.path.join(file_path, "..", "..") - ) # Adjust as needed - if project_root not in sys.path: - sys.path.insert(0, project_root) - - spec = importlib.util.spec_from_file_location(module_name, file_path) - if not spec or not spec.loader: - raise ImportError(f"Could not load module from {file_path}") - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - tools = [] - for name, obj in inspect.getmembers(module): # noqa: B007 - if isinstance(obj, BaseTool): 
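- # LangChain BaseTool instances expose name/args_schema metadata; keep only well-formed ones here, and fall back to plain callables in the else branch below.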
- if hasattr(obj, "name") and hasattr(obj, "args_schema"): - tools.append(obj) - else: - if callable(obj) and obj.__name__ != "tool": - tools.append(obj) - - return tools - - -class TextToolPolicyGenerator: - def __init__( - self, llm: I_TG_LLM, policy_document: str, tools: List[ToolInfo], out_dir: str - ) -> None: - self.llm = llm - self.policy_document = policy_document - self.tools_descriptions = {tool.name: tool.description for tool in tools} - self.tools_details = {tool.name: tool for tool in tools} - self.out_dir = out_dir - - async def generate_minimal_policy(self, tool_name: str) -> dict: - tptd = await self.create_policy(tool_name) - tptd = await self.example_creator(tool_name, tptd) - return tptd - - async def generate_policy(self, tool_name: str) -> dict: - tptd = await self.create_policy(tool_name) - for i in range(3): - tptd = await self.add_policies(tool_name, tptd, i) - tptd = await self.split(tool_name, tptd) - tptd = await self.merge(tool_name, tptd) - tptd = await self.review_policy(tool_name, tptd) - tptd = await self.add_references(tool_name, tptd) - tptd = await self.reference_correctness(tool_name, tptd) - tptd = await self.example_creator(tool_name, tptd) - for i in range(5): # FIXME - tptd = await self.add_examples(tool_name, tptd, i) - tptd = await self.merge_examples(tool_name, tptd) - # tptd = self.fix_examples(tool_name, tptd) - tptd = await self.review_examples(tool_name, tptd) - return tptd - - async def create_policy(self, tool_name: str) -> dict: - print("policy_creator_node") - system_prompt = read_prompt_file("create_policy") - system_prompt = system_prompt.replace("ToolX", tool_name) - user_content = f"Policy Document:{self.policy_document}\nTools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\n" - tptd = await self.llm.chat_json(generate_messages(system_prompt, user_content)) - save_output(self.out_dir, f"{tool_name}.json", tptd) - return tptd - - async def add_policies( - self, tool_name: str, tptd: dict, iteration: int = 0 - ) -> dict: - print("add_policy") - system_prompt = read_prompt_file("add_policies") - user_content = f"Policy Document:{self.policy_document}\nTools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nTPTD: {json.dumps(tptd)}" - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - - policies = ( - response["additionalProperties"]["policies"] - if "additionalProperties" in response and "policies" not in response - else response["policies"] - ) - - for policy in policies: - # for policy in response["policies"]: - policy["iteration"] = iteration - tptd["policies"].append(policy) - - save_output(self.out_dir, f"{tool_name}_ADD_{iteration}.json", tptd) - return tptd - - async def split(self, tool_name, tptd: dict) -> dict: - # todo: consider addition step to split policy by policy and not overall - print("split") - system_prompt = read_prompt_file("split") - user_content = f"Policy Document:{self.policy_document}\nTools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nTPTD: {json.dumps(tptd)}" - tptd = await self.llm.chat_json(generate_messages(system_prompt, user_content)) - save_output(self.out_dir, f"{tool_name}_split.json", tptd) - return tptd - - async def merge(self, tool_name, tptd: dict) -> dict: - # todo: consider addition step to split policy by policy and not overall - print("merge") - 
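- # Ask the LLM, via the "merge" prompt shown earlier in this diff, to consolidate duplicate policies and OR-related conditions in the TPTD.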
system_prompt = read_prompt_file("merge") - user_content = f"Policy Document:{self.policy_document}\nTools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nTPTD: {json.dumps(tptd)}" - tptd = await self.llm.chat_json(generate_messages(system_prompt, user_content)) - - save_output(self.out_dir, f"{tool_name}_merge.json", tptd) - return tptd - - def move2archive(self, reviews) -> tuple[bool, str]: - comments = "" - num = len(reviews) - if num == 0: - return False, "" - counts = { - "is_relevant": 0, - "is_tool_specific": 0, - "can_be_validated": 0, - "is_actionable": 0, - } - - for r in reviews: - print( - f"{r['is_relevant'] if 'is_relevant' in r else ''}\t{r['is_tool_specific'] if 'is_tool_specific' in r else ''}\t{r['can_be_validated'] if 'can_be_validated' in r else ''}\t{r['is_actionable'] if 'is_actionable' in r else ''}\t{r['is_self_contained'] if 'is_self_contained' in r else ''}\t{r['score'] if 'score' in r else ''}\t" - ) - - counts["is_relevant"] += r["is_relevant"] if "is_relevant" in r else 0 - counts["is_tool_specific"] += ( - r["is_tool_specific"] if "is_tool_specific" in r else 0 - ) - counts["can_be_validated"] += ( - r["can_be_validated"] if "can_be_validated" in r else 0 - ) - counts["is_actionable"] += r["is_actionable"] if "is_actionable" in r else 0 - - if not all( - e in r - for e in [ - "is_relevant", - "is_tool_specific", - "can_be_validated", - "is_actionable", - ] - ) or not ( - r["is_relevant"] - and r["is_tool_specific"] - and r["can_be_validated"] - and r["is_actionable"] - ): - comments += r["comments"] + "\n" - - return not (all(float(counts[key]) / num > 0.5 for key in counts)), comments - - async def review_policy(self, tool_name, tptd) -> dict: - print("review_policy") - system_prompt = read_prompt_file("policy_reviewer") - newTPTD = {"policies": []} - - if "policies" not in tptd: - tptd["policies"] = [] - - for policy in tptd["policies"]: - reviews = [] - for _iteration in range(5): - user_content = f"Policy Document:{self.policy_document}\nTools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{json.dumps(self.tools_descriptions[tool_name])}\npolicy: {json.dumps(policy)}" - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - if "is_self_contained" in response: - is_self_contained = response["is_self_contained"] - if not is_self_contained: - if "alternative_description" in response: - policy["description"] = response["alternative_description"] - else: - print( - "Error: review is_self_contained is false but no alternative_description." 
- ) - else: - print("Error: review did not provide is_self_contained.") - reviews.append(response) - archive, comments = self.move2archive(reviews) - print(archive) - if archive: - if "archive" not in newTPTD: - newTPTD["archive"] = [] - policy["comments"] = comments - newTPTD["archive"].append(policy) - else: - newTPTD["policies"].append(policy) - save_output(self.out_dir, f"{tool_name}_rev.json", newTPTD) - return newTPTD - - async def add_references(self, tool_name: str, tptd: dict) -> dict: - print("add_ref") - system_prompt = read_prompt_file("add_references") - # remove old refs (used to help avoid duplications) - for policy in tptd["policies"]: - policy["references"] = [] - user_content = f"Policy Document:{self.policy_document}\nTools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\npolicy: {json.dumps(policy)}" - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - if "references" in response: - policy["references"] = response["references"] - else: - print("Error! no references in response") - print(response) - - save_output(self.out_dir, f"{tool_name}_ref.json", tptd) - return tptd - - async def reference_correctness(self, tool_name: str, tptd: dict) -> dict: - print("reference_correctness") - tptd, unmatched_policies = find_mismatched_references( - self.policy_document, tptd - ) - save_output(self.out_dir, f"{tool_name}_ref_orig_.json", unmatched_policies) - save_output(self.out_dir, f"{tool_name}_ref_correction_.json", tptd) - return tptd - - async def example_creator(self, tool_name: str, tptd: dict) -> dict: - print("example_creator") - system_prompt = read_prompt_file("create_examples") - system_prompt = system_prompt.replace("ToolX", tool_name) - - for policy in tptd["policies"]: - # user_content = f"Policy Document:{state['policy_text']}\nTools Descriptions:{json.dumps(state['tools'])}\nTarget Tool:{json.dumps(state['target_tool_description'])}\nPolicy:{policy}" - user_content = f"Tools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nPolicy:{policy}" - - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - if "violating_examples" in response: - policy["violating_examples"] = response["violating_examples"] - - if "compliance_examples" in response: - policy["compliance_examples"] = response["compliance_examples"] - - save_output(self.out_dir, f"{tool_name}_examples.json", tptd) - return tptd - - async def add_examples(self, tool_name: str, tptd: dict, iteration: int) -> dict: - print("add_examples") - system_prompt = read_prompt_file("add_examples") - system_prompt = system_prompt.replace("ToolX", tool_name) - for policy in tptd["policies"]: - # user_content = f"Policy Document:{state['policy_text']}\nTools Descriptions:{json.dumps(state['tools'])}\nTarget Tool:{json.dumps(state['target_tool_description'])}\nPolicy:{policy}" - user_content = f"Tools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nPolicy:{policy}" - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - if "violating_examples" in response: - for vexample in response["violating_examples"]: - # vexample["iteration"] = state["iteration"] - if "violating_examples" not in policy: - policy["violating_examples"] = [] - policy["violating_examples"].append(vexample) - if "compliance_examples" in response: - 
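- # Mirror the violating-examples branch: initialize the list on first use, then append each newly generated compliance example.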
for cexample in response["compliance_examples"]: - if "compliance_examples" not in policy: - policy["compliance_examples"] = [] - # cexample["iteration"] = state["iteration"] - policy["compliance_examples"].append(cexample) - - save_output(self.out_dir, f"{tool_name}_ADD_examples{iteration}.json", tptd) - return tptd - - async def merge_examples(self, tool_name: str, tptd: dict) -> dict: - print("merge_examples") - system_prompt = read_prompt_file("merge_examples") - system_prompt = system_prompt.replace("ToolX", tool_name) - for policy in tptd["policies"]: - # user_content = f"Policy Document:{state['policy_text']}\nTools Descriptions:{json.dumps(state['tools'])}\nTarget Tool:{json.dumps(state['target_tool_description'])}\nPolicy Name:{policy['policy_name']}\nPolicy Description:{policy['description']}" - user_content = f"Tools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nPolicy Name:{policy['policy_name']}\nPolicy Description:{policy['description']}" - user_content += f"\n\nViolating Examples: {policy['violating_examples']}" - user_content += f"\n\nCompliance Examples: {policy['compliance_examples']}" - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - policy["violating_examples"] = response["violating_examples"] - policy["compliance_examples"] = response["compliance_examples"] - - save_output(self.out_dir, f"{tool_name}_merge_examples.json", tptd) - return tptd - - async def fix_examples(self, tool_name: str, tptd: dict) -> dict: - print("fix_examples") - orig_prompt = read_prompt_file("fix_example") - for policy in tptd["policies"]: - for etype in ["violating", "compliance"]: - fixed_examples = [] - for example in policy[etype + "_examples"]: - system_prompt = orig_prompt.replace("ToolX", tool_name) - system_prompt = system_prompt.replace("__EXAMPLE_TYPE__", "") - - # user_content = f"Policy Document:{state['policy_text']}\nTools Descriptions:{json.dumps(state['tools'])}\nTarget Tool:{json.dumps(state['target_tool_description'])}\nPolicy Name:{policy['policy_name']}\nPolicy Description:{policy['description']}\nExample:{example}" - user_content = f"Tools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget Tool:{self.tools_details[tool_name].model_dump_json()}\nPolicy Name:{policy['policy_name']}\nPolicy Description:{policy['description']}\nExample:{example}" - - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - fixed_examples.append(response["revised_example"]) - policy[etype + "_examples"] = fixed_examples - - save_output(self.out_dir, f"{tool_name}_fix_examples.json", tptd) - return tptd - - # todo: change to revew examples, write prompts - async def review_examples(self, tool_name: str, tptd: dict) -> dict: - print("review_examples") - system_prompt = read_prompt_file("examples_reviewer") - for policy in tptd["policies"]: - print(policy["policy_name"]) - for etype in ["violating", "compliance"]: - print(etype) - passed_examples = [] - for example in policy[etype + "_examples"]: - print(example) - reviews = [] - for _iteration in range(5): - # user_content = f"Policy Document:{state['policy_text']}\nTools Descriptions:{json.dumps(state['tools'])}\nTarget Tool:{json.dumps(state['target_tool_description'])}\nPolicy Name:{policy['policy_name']}\nPolicy Description:{policy['description']}\nExample:{example}" - user_content = f"Tools Descriptions:{json.dumps(self.tools_descriptions)}\nTarget 
Tool:{self.tools_details[tool_name].model_dump_json()}\nPolicy Name:{policy['policy_name']}\nPolicy Description:{policy['description']}\nExample:{example}" - response = await self.llm.chat_json( - generate_messages(system_prompt, user_content) - ) - reviews.append(response) - keep = self.keep_example(reviews) - if keep: - passed_examples.append(example) - - policy[etype + "_examples"] = passed_examples - - save_output(self.out_dir, f"{tool_name}_example_rev.json", tptd) - return tptd - - def keep_example(self, reviews) -> bool: - bads = 0 - totals = 0 - for r in reviews: - for vals in r.values(): - totals += 1 - if "value" not in vals: - print(reviews) - elif not vals["value"]: - bads += 1 - if bads / totals > 0.8: - return False - return True - - -async def extract_policies( - policy_text: str, - tools: List[ToolInfo], - step1_output_dir: str, - llm: I_TG_LLM, - tools_shortlist: Optional[List[str]] = None, - short=False, -) -> List[ToolPolicy]: - if not os.path.isdir(step1_output_dir): - os.makedirs(step1_output_dir) - - process_dir = os.path.join(step1_output_dir, "process") - if not os.path.isdir(process_dir): - os.makedirs(process_dir) - output_tool_policies = [] - tpg = TextToolPolicyGenerator(llm, policy_text, tools, process_dir) - - async def do_one_tool(tool_name): - if short: - final_output = await tpg.generate_minimal_policy(tool_name) - else: - final_output = await tpg.generate_policy(tool_name) - - with open(os.path.join(step1_output_dir, tool_name + ".json"), "w") as outfile1: - outfile1.write(json.dumps(final_output, indent=2)) - output_tool_policies.append( - load_tool_policy( - os.path.join(step1_output_dir, tool_name + ".json"), tool_name - ) - ) - - await asyncio.gather( - *[ - do_one_tool(tool.name) - for tool in tools - if ((tools_shortlist is None) or (tool.name in tools_shortlist)) - ] - ) - print("All tools done") - return output_tool_policies diff --git a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/utils.py b/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/utils.py deleted file mode 100644 index b2bd518..0000000 --- a/altk/pre_tool/toolguard/toolguard/tool_policy_extractor/utils.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Dict, Any -import os - -import json -from typing import List - - -def read_prompt_file(filename: str) -> str: - with open( - os.path.join(os.path.dirname(__file__), "prompts", filename + ".txt"), "r" - ) as f: - return f.read() - - -def generate_messages(system_prompt: str, user_content: str) -> List[Dict[str, str]]: - return [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_content}, - ] - - -def save_output(outdir: str, filename: str, content: Any): - with open(os.path.join(outdir, filename), "w") as outfile: - json.dump(content, outfile, indent=4) - - -def normalize_text(text): - """Normalize text by removing punctuation, converting to lowercase, and standardizing spaces.""" - # return re.sub(r'\s+', ' ', re.sub(r'[^a-zA-Z0-9\s]', '', text)).strip().lower() - return text.lower() - - -def split_reference_if_both_parts_exist(reference, policy_text): - words = reference.split() - for split_point in range(1, len(words)): - part1 = " ".join(words[:split_point]) - part2 = " ".join(words[split_point:]) - - normalized_part1 = normalize_text(part1) - normalized_part2 = normalize_text(part2) - normalized_policy_text = normalize_text(policy_text) - - if ( - normalized_part1 in normalized_policy_text - and normalized_part2 in normalized_policy_text - ): - start_idx1 = 
normalized_policy_text.find(normalized_part1) - end_idx1 = start_idx1 + len(part1) - start_idx2 = normalized_policy_text.find(normalized_part2) - end_idx2 = start_idx2 + len(part2) - return [policy_text[start_idx1:end_idx1], policy_text[start_idx2:end_idx2]] - return None - - -def find_mismatched_references(policy_text, policy_json): - corrections = json.loads(json.dumps(policy_json)) - unmatched_policies = [] - if isinstance(corrections["policies"], str): - return corrections, unmatched_policies - - normalized_policy_text = normalize_text(policy_text) - - for policy in corrections["policies"]: - corrected_references = [] - has_unmatched = False - - for reference in policy["references"]: - normalized_ref = normalize_text(reference) - # if normalized ref in policy doc- just copy the original - if normalized_ref in normalized_policy_text: - start_idx = normalized_policy_text.find(normalized_ref) - end_idx = start_idx + len(reference) - corrected_references.append(policy_text[start_idx:end_idx]) - else: - # close_match = get_close_matches(normalized_ref, [normalized_policy_text], n=1, cutoff=0.9) - # if close_match: - # start_idx = normalized_policy_text.find(close_match[0]) - # end_idx = start_idx + len(close_match[0]) - # corrected_references.append(policy_text[start_idx:end_idx]) - # else: - split_segments = split_reference_if_both_parts_exist( - reference, policy_text - ) - if split_segments: - corrected_references.extend(split_segments) - else: - corrected_references.append( - reference - ) # Keep original if no match found - has_unmatched = True - - policy["references"] = corrected_references - if has_unmatched: - unmatched_policies.append(policy["policy_name"]) - - return corrections, unmatched_policies diff --git a/altk/pre_tool/toolguard/toolguard_code_component.py b/altk/pre_tool/toolguard/toolguard_code_component.py new file mode 100644 index 0000000..6795079 --- /dev/null +++ b/altk/pre_tool/toolguard/toolguard_code_component.py @@ -0,0 +1,104 @@ +import logging +from typing import Any, Callable, Dict, List, cast +from enum import Enum +from pydantic import BaseModel, Field +from typing import Set +from langchain_core.tools import BaseTool + +from altk.core.toolkit import ComponentConfig, ComponentInput, AgentPhase, ComponentBase +from toolguard import generate_guards_from_specs, ToolGuardSpec, load_toolguards +from toolguard.runtime import IToolInvoker, ToolGuardsCodeGenerationResult + +from altk.pre_tool.toolguard.llm_client import TG_LLMEval + +logger = logging.getLogger(__name__) + +class ToolGuardCodeComponentConfig(ComponentConfig): + pass + +class ToolGuardCodeBuildInput(ComponentInput): + tools: List[Callable] | List[BaseTool] | str + toolguard_specs: List[ToolGuardSpec] + out_dir: str + +ToolGuardBuildOutput = ToolGuardsCodeGenerationResult + +class ToolGuardCodeRunInput(ComponentInput): + generated_guard_dir: str + tool_name: str = Field(description="Tool name") + tool_args: Dict[str, Any] = Field(default={}, description="Tool arguments") + tool_invoker: IToolInvoker + + model_config = { + "arbitrary_types_allowed": True + } + +class ViolationLevel(Enum): + """Severity level of a safety violation. 
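+ Assigned to PolicyViolation.violation_level below; the code component reports ERROR when a generated guard blocks a tool call.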
+ + :cvar INFO: Informational level violation that does not require action + :cvar WARN: Warning level violation that suggests caution but allows continuation + :cvar ERROR: Error level violation that requires blocking or intervention + """ + + INFO = "info" + WARN = "warn" + ERROR = "error" + +class PolicyViolation(BaseModel): + """Details of a safety violation detected by content moderation. + + :param violation_level: Severity level of the violation + :param user_message: (Optional) Message to convey to the user about the violation + """ + + violation_level: ViolationLevel + + # what message should you convey to the user + user_message: str | None = None + +class ToolGuardCodeRunOutput(BaseModel): + violation: PolicyViolation | None = None + + +class ToolGuardCodeComponent(ComponentBase): + + def __init__(self, config: ToolGuardCodeComponentConfig): + super().__init__(config=config) + + @classmethod + def supported_phases(cls) -> Set[AgentPhase]: + """Return the supported agent phases.""" + return {AgentPhase.BUILDTIME, AgentPhase.RUNTIME} + + def _build(self, data: ToolGuardCodeBuildInput) -> ToolGuardsCodeGenerationResult: + raise NotImplementedError("Please use the aprocess() function in an async context") + + async def _abuild(self, data: ToolGuardCodeBuildInput) -> ToolGuardsCodeGenerationResult: + config = cast(ToolGuardCodeComponentConfig, self.config) + llm = TG_LLMEval(config.llm_client) + return await generate_guards_from_specs( + tools=data.tools, + tool_specs=data.toolguard_specs, + work_dir=data.out_dir, + llm=llm + ) + + def _run(self, data: ToolGuardCodeRunInput) -> ToolGuardCodeRunOutput: + code_root_dir = data.generated_guard_dir + tool_name = data.tool_name + tool_params = data.tool_args + with load_toolguards(code_root_dir) as toolguards: + from rt_toolguard.data_types import PolicyViolationException + try: + toolguards.check_toolcall(tool_name, tool_params, data.tool_invoker) + return ToolGuardCodeRunOutput() + except PolicyViolationException as e: + return ToolGuardCodeRunOutput(violation=PolicyViolation( + violation_level=ViolationLevel.ERROR, + user_message=str(e) + )) + + async def _arun(self, data: ToolGuardCodeRunInput) -> ToolGuardCodeRunOutput: + return self._run(data) + diff --git a/altk/pre_tool/toolguard/toolguard_spec_component.py b/altk/pre_tool/toolguard/toolguard_spec_component.py new file mode 100644 index 0000000..c141758 --- /dev/null +++ b/altk/pre_tool/toolguard/toolguard_spec_component.py @@ -0,0 +1,46 @@ +import logging +import os +from typing import Callable, List, Set, cast +from langchain_core.tools import BaseTool +from pydantic import Field + +from altk.pre_tool.toolguard.llm_client import TG_LLMEval + +from ...core.toolkit import AgentPhase, ComponentBase, ComponentConfig, ComponentInput +from toolguard import ToolGuardSpec, generate_guard_specs + +logger = logging.getLogger(__name__) + +class ToolGuardSpecComponentConfig(ComponentConfig): + pass + +class ToolGuardSpecBuildInput(ComponentInput): + policy_text: str = Field(description="Text of the policy document file") + tools: List[Callable] | List[BaseTool] | str + out_dir: str + +ToolGuardSpecs = List[ToolGuardSpec] + +class ToolGuardSpecComponent(ComponentBase): + + def __init__(self, config: ToolGuardSpecComponentConfig): + super().__init__(config=config) + + @classmethod + def supported_phases(cls) -> Set[AgentPhase]: + return {AgentPhase.BUILDTIME, AgentPhase.RUNTIME} + + def _build(self, data: ToolGuardSpecBuildInput) -> ToolGuardSpecs: + raise NotImplementedError("Please use the aprocess() 
function in an async context") + + async def _abuild(self, data: ToolGuardSpecBuildInput) -> ToolGuardSpecs: + os.makedirs(data.out_dir, exist_ok=True) + config = cast(ToolGuardSpecComponentConfig, self.config) + llm = TG_LLMEval(config.llm_client) + return await generate_guard_specs( + policy_text=data.policy_text, + tools=data.tools, + work_dir=data.out_dir, + llm=llm + ) + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4fca8eb..9f58282 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,11 +33,6 @@ dependencies = [ "langchain-text-splitters>=1.0.0", "nltk>=3.9.1", "scipy>=1.15.3", - "pytest>=8.2.1", - "pytest-json-report>=1.5.0", - "pyright>=1.1.406", - "datamodel-code-generator>=0.34.0", - "mellea~=0.0.6", ] description = "The Agent Lifecycle Toolkit (ALTK) is a library of components to help agent builders improve their agent with minimal integration effort and setup." @@ -134,6 +129,10 @@ spotlight = [ "transformers>=4.53.3", ] +toolguard = [ + "toolguard>=0.1.14" +] + refraction = [ "nl2flow>=0.1.2; sys_platform != 'win32'", "sentence-transformers>=5.0.0", diff --git a/tests/pre_tool/toolguard/.gitignore b/tests/pre_tool/toolguard/.gitignore new file mode 100644 index 0000000..e6d35e7 --- /dev/null +++ b/tests/pre_tool/toolguard/.gitignore @@ -0,0 +1 @@ +outputs \ No newline at end of file diff --git a/altk/pre_tool/examples/__init__.py b/tests/pre_tool/toolguard/inputs/__init__.py similarity index 100% rename from altk/pre_tool/examples/__init__.py rename to tests/pre_tool/toolguard/inputs/__init__.py diff --git a/tests/pre_tool/toolguard/inputs/oas.json b/tests/pre_tool/toolguard/inputs/oas.json new file mode 100644 index 0000000..f6da957 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/oas.json @@ -0,0 +1,284 @@ +{ + "openapi": "3.0.3", + "info": { + "title": "Math Tools API", + "version": "1.0.0", + "description": "A simple API providing basic mathematical operations: addition, subtraction, multiplication, and division." 
+ }, + "paths": { + "/tools/add": { + "post": { + "operationId": "add_tool", + "summary": "Add two numbers", + "description": "Add two numbers and return their sum.", + "x-input-examples": [ + "add_tool(1.0, 2.0)", + "add_tool(-5.5, 3.2)", + "add_tool(0, 100.75)", + "add_tool(1e3, 2e3)" + ], + "x-output-examples": [3.0, -2.3, 100.75, 3000.0], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number.", + "example": 5.0 + }, + "b": { + "type": "number", + "description": "The second number.", + "example": 3.0 + } + }, + "required": ["a", "b"] + } + } + } + }, + "responses": { + "200": { + "description": "The sum of a and b.", + "content": { + "application/json": { + "schema": { + "type": "number", + "example": 8.0 + } + } + } + } + } + } + }, + "/tools/subtract": { + "post": { + "operationId": "subtract_tool", + "summary": "Subtract one number from another", + "description": "Subtract one number from another and return the result.", + "x-input-examples": [ + "subtract_tool(10.0, 4.0)", + "subtract_tool(0.0, 1.5)", + "subtract_tool(-3.2, -5.8)", + "subtract_tool(1000, 1)" + ], + "x-output-examples": [6.0, -1.5, 2.6, 999.0], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The number to subtract from.", + "example": 10.0 + }, + "b": { + "type": "number", + "description": "The number to subtract.", + "example": 4.0 + } + }, + "required": ["a", "b"] + } + } + } + }, + "responses": { + "200": { + "description": "The result of a minus b.", + "content": { + "application/json": { + "schema": { + "type": "number", + "example": 6.0 + } + } + } + } + } + } + }, + "/tools/multiply": { + "post": { + "operationId": "multiply_tool", + "summary": "Multiply two numbers", + "description": "Multiply two numbers and return their product.", + "x-input-examples": [ + "multiply_tool(2.0, 3.5)", + "multiply_tool(-1.2, 4.0)", + "multiply_tool(0, 999.99)", + "multiply_tool(100, 0.01)" + ], + "x-output-examples": [7.0, -4.8, 0.0, 1.0], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number.", + "example": 2.0 + }, + "b": { + "type": "number", + "description": "The second number.", + "example": 3.5 + } + }, + "required": ["a", "b"] + } + } + } + }, + "responses": { + "200": { + "description": "The product of a and b.", + "content": { + "application/json": { + "schema": { + "type": "number", + "example": 7.0 + } + } + } + } + } + } + }, + "/tools/divide": { + "post": { + "operationId": "divide_tool", + "summary": "Divide one number by another", + "description": "Divide one number by another. 
The divisor must not be zero.", + "x-input-examples": [ + "divide_tool(10.0, 2.0)", + "divide_tool(3.0, 0.5)", + "divide_tool(-9.0, 3.0)", + "divide_tool(5.0, -2.5)" + ], + "x-output-examples": [5.0, 6.0, -3.0, -2.0], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "g": { + "type": "number", + "description": "The dividend.", + "example": 10.0 + }, + "h": { + "type": "number", + "description": "The divisor (must not be zero).", + "example": 2.0 + } + }, + "required": ["g", "h"] + } + } + } + }, + "responses": { + "200": { + "description": "The result of g divided by h.", + "content": { + "application/json": { + "schema": { + "type": "number", + "example": 5.0 + } + } + } + }, + "400": { + "description": "Division by zero is not allowed.", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "error": { + "type": "string", + "example": "Division by zero is not allowed." + } + } + } + } + } + } + } + } + }, + "/tools/map_kdi_number": { + "post": { + "operationId": "map_kdi_number", + "summary": "Compute the kdi number", + "description": "Compute the kdi number", + "x-input-examples": [ + "map_kdi_number(2.0)" + ], + "x-output-examples": [-2.0], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "i": { + "type": "number", + "description": "The number to compute.", + "example": 10.0 + } + }, + "required": ["i"] + } + } + } + }, + "responses": { + "200": { + "description": "The kdi index.", + "content": { + "application/json": { + "schema": { + "type": "number", + "example": 5.0 + } + } + } + }, + "400": { + "description": "Bad request.", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "error": { + "type": "string", + "example": "not a number" + } + } + } + } + } + } + } + } + } + } +} diff --git a/altk/pre_tool/toolguard/examples/calculator_example/policy_document.md b/tests/pre_tool/toolguard/inputs/policy_doc.md similarity index 80% rename from altk/pre_tool/toolguard/examples/calculator_example/policy_document.md rename to tests/pre_tool/toolguard/inputs/policy_doc.md index ce70740..f476eb8 100644 --- a/altk/pre_tool/toolguard/examples/calculator_example/policy_document.md +++ b/tests/pre_tool/toolguard/inputs/policy_doc.md @@ -15,22 +15,22 @@ The calculator supports the following operations: - Subtraction (`-`) - Multiplication (`*`) - Division (`/`) -- Map KDI +- Map KDI ## Operation Constraints -- **Division by Zero is Not Allowed** - The calculator **must not** allow division by zero. - If a user attempts to divide by zero, the operation must be rejected and an appropriate error message should be shown +- **Division by Zero is Not Allowed** + The calculator **must not** allow division by zero. + If a user attempts to divide by zero, the operation must be rejected and an appropriate error message should be shown (e.g., `"Error: Division by zero is not allowed."`). -- **Summing Numbers Whose Product is 365 is Not Allowed** - The calculator **must not** allow addition of two or more numbers if their multiplication result equals `365`. - For example, adding `5 + 73` should be disallowed, because `5 * 73 = 365` . 
- In such cases, the operation must be rejected with an error like: +- **Summing Numbers Whose Product is 365 is Not Allowed** + The calculator **must not** allow addition of two or more numbers if their multiplication result equals `365`. + For example, adding `5 + 73` should be disallowed, because `5 * 73 = 365`. + In such cases, the operation must be rejected with an error like: `"Error: Addition of numbers whose product equals 365 is not allowed."` - - **Multiplying Numbers When Any Operand’s KDI Value Equals 6.28 is Not Allowed** + - **Multiplying Numbers When Any Operand’s KDI Value Equals 6.28 is Not Allowed** The calculator must not allow multiplication if any of the operands correspond to a number whose KDI value equals 6.28. For example, if a number x has KDI(x) = 6.28, then an operation such as x * 2 must be rejected. In such cases, the calculator should display an error message such as: diff --git a/tests/pre_tool/toolguard/inputs/step1/add_tool.json b/tests/pre_tool/toolguard/inputs/step1/add_tool.json new file mode 100644 index 0000000..21f5a53 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/step1/add_tool.json @@ -0,0 +1,25 @@ +{ + "tool_name": "add_tool", + "policy_items": [ + { + "name": "Prohibition of Adding Numbers Whose Product Equals 365", + "description": "The add_tool must not allow addition of two or more numbers if their multiplication result equals 365. For example, adding 5 + 73 should be disallowed, because 5 * 73 = 365. In such cases, the operation must be rejected with an error like: 'Error: Addition of numbers whose product equals 365 is not allowed.'", + "references": [ + "The calculator must not allow addition of two or more numbers if their multiplication result equals 365. For example, adding 5 + 73 should be disallowed, because 5 * 73 = 365. In such cases, the operation must be rejected with an error like: 'Error: Addition of numbers whose product equals 365 is not allowed.'" + ], + "iteration_added": 0, + "violation_examples": [ + "A user tries to add 5 and 73. Since 5 * 73 equals 365, this operation should be rejected with an error message: 'Error: Addition of numbers whose product equals 365 is not allowed.'", + "The user attempts to add two numbers, 1 and 365, using the add_tool. This is a violation because 1 * 365 is also 365 and should be met with an error message.", + "Adding together the numbers 10 and 36.5 should be disallowed because their product is 365. An error message should be shown to prevent this operation.", + "A user inputs 365 and 1 into the add_tool. This action violates the policy as 365 multiplied by 1 is 365, warranting an error message rejection.", + "The system tries to perform the addition of -5 and -73 using the add_tool. Despite the negative signs, their product still equals 365 and thus contravenes the policy, requiring an error message to be displayed." + ], + "compliance_examples": [ + "Performing an addition of 10 and 20, where their product is 200, complies with the policy since it does not equal 365, and the addition proceeds without errors.", + "A user adds together the numbers 30 and 40. Since their product is 1200, which is not 365, this operation adheres to the policy without issue.", + "Addition of the numbers 123 and 3 is compliant: their product is 369, not 365, so the operation is allowed to proceed." 
+ ] + } + ] +} \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/step1/divide_tool.json b/tests/pre_tool/toolguard/inputs/step1/divide_tool.json new file mode 100644 index 0000000..0744ff4 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/step1/divide_tool.json @@ -0,0 +1,23 @@ +{ + "tool_name": "divide_tool", + "policy_items": [ + { + "name": "Division by Zero is Not Allowed", + "description": "The calculator must not allow division by zero. If a user attempts to divide by zero, the operation must be rejected and an appropriate error message should be shown (e.g., 'Error: Division by zero is not allowed.')", + "references": [ + "The calculator must not allow division by zero. If a user attempts to divide by zero, the operation must be rejected and an appropriate error message should be shown (e.g., 'Error: Division by zero is not allowed.')" + ], + "iteration_added": 0, + "violation_examples": [ + "A user attempts to divide 15 by 0, but the program proceeds with the operation, resulting in either an undefined behavior or a system error instead of an appropriate error message being shown.", + "The system calculates the result of 42 divided by 0 without any checks, ignoring the divisor being zero and directly providing an undefined or erroneous result.", + "A user input of 80 divided by 0 is processed and the system returns a zero or a nonsensical number instead of stopping the operation and displaying an error message about division by zero.", + "A user attempts to divide 25 by 0 and the application proceeds with the operation instead of rejecting it and displaying the message: 'Error: Division by zero is not allowed.'" + ], + "compliance_examples": [ + "Upon entering inputs of 0 divided by 15, the tool successfully executes the operation and returns 0, complying with the division rules since the divisor (15) is non-zero.", + "A calculation involving 132 divided by 11 is processed correctly, returning a result of 12 without any error since it adheres to all division rules." + ] + } + ] +} \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/step1/map_kdi_number.json b/tests/pre_tool/toolguard/inputs/step1/map_kdi_number.json new file mode 100644 index 0000000..2aae27e --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/step1/map_kdi_number.json @@ -0,0 +1,4 @@ +{ + "tool_name": "map_kdi_number", + "policy_items": [] +} \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/step1/multiply_tool.json b/tests/pre_tool/toolguard/inputs/step1/multiply_tool.json new file mode 100644 index 0000000..612e7c3 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/step1/multiply_tool.json @@ -0,0 +1,23 @@ +{ + "tool_name": "multiply_tool", + "policy_items": [ + { + "name": "Multiplication KDI Constraint", + "description": "Multiplying numbers when any operand\u2019s KDI value equals 6.28 is not allowed. The calculator must not allow multiplication if any of the operands correspond to a number whose KDI value equals 6.28. For example, if a number x has KDI(x) = 6.28, then an operation such as x * 2 must be rejected.", + "references": [ + "Multiplying Numbers When Any Operand\u2019s KDI Value Equals 6.28 is Not Allowed\nThe calculator must not allow multiplication if any of the operands correspond to a number whose KDI value equals 6.28. For example, if a number x has KDI(x) = 6.28, then an operation such as x * 2 must be rejected. 
In such cases, the calculator should display an error message such as: \"Error: Multiplication involving a number with KDI value 6.28 is not allowed.\"" + ], + "iteration_added": 0, + "violation_examples": [ + "Attempt to multiply 5 by 3 where KDI(5) = 6.28. This violates the policy since using any operand with KDI value 6.28 in multiplication is not permitted.", + "Operation tries to calculate the product of 2 and 10, but KDI(10) yields 6.28. This is a breach of policy because multiplying a number mapped to KDI value 6.28 is disallowed.", + "A multiplication command of 8 * 4 is executed, and upon retrieval, KDI(4) = 6.28. Using 4 in multiplication should be rejected per the policy." + ], + "compliance_examples": [ + "Perform multiplication with operands 7 and 3. Considering that KDI(7) != 6.28 and KDI(3) != 6.28, the multiplication is allowed.", + "Calculate the product of 2.5 and 4.2, where neither operand has a KDI value of 6.28, complying with the given constraints.", + "Execute multiplication of 9 and 1, confirming beforehand that the KDI values for 9 and 1 do not equal 6.28, thus following policy rules." + ] + } + ] +} \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/step1/subtract_tool.json b/tests/pre_tool/toolguard/inputs/step1/subtract_tool.json new file mode 100644 index 0000000..5b96883 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/step1/subtract_tool.json @@ -0,0 +1,4 @@ +{ + "tool_name": "subtract_tool", + "policy_items": [] +} \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/tool_functions.py b/tests/pre_tool/toolguard/inputs/tool_functions.py new file mode 100644 index 0000000..dfab5e4 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/tool_functions.py @@ -0,0 +1,66 @@ +def divide_tool(g: float, h: float) -> float: + """Divides one number by another. + + Args: + g (float): The dividend. + h (float): The divisor (must not be zero). + + Returns: + float: The result of g divided by h. + + Raises: + ZeroDivisionError: If h is zero. + """ + return g / h + + +def add_tool(a: float, b: float) -> float: + """Adds two numbers. + + Args: + a (float): The first number. + b (float): The second number. + + Returns: + float: The sum of a and b. + """ + return a + b + + +def subtract_tool(a: float, b: float) -> float: + """Subtracts one number from another. + + Args: + a (float): The number to subtract from. + b (float): The number to subtract. + + Returns: + float: The result of a minus b. + """ + return a - b + + +def multiply_tool(a: float, b: float) -> float: + """Multiplies two numbers. + + Args: + a (float): The first number. + b (float): The second number. + + Returns: + float: The product of a and b. + """ + return a * b + + +def map_kdi_number(i: float) -> float: + """ + Return the mapping of the number i to its KDI value. + + Args: + i (float): The number to map. + + Returns: + float: The KDI value of the given number. + """ + return 3.14 * i \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/tool_langchain.py b/tests/pre_tool/toolguard/inputs/tool_langchain.py new file mode 100644 index 0000000..57b8e2b --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/tool_langchain.py @@ -0,0 +1,70 @@ +from langchain_core.tools import tool + +@tool() +def divide_tool(g: float, h: float) -> float: + """Divides one number by another. + + Args: + g (float): The dividend. + h (float): The divisor (must not be zero). + + Returns: + float: The result of g divided by h. 
+ + Raises: + ZeroDivisionError: If h is zero. + """ + return g / h + +@tool() +def add_tool(a: float, b: float) -> float: + """Adds two numbers. + + Args: + a (float): The first number. + b (float): The second number. + + Returns: + float: The sum of a and b. + """ + return a + b + +@tool() +def subtract_tool(a: float, b: float) -> float: + """Subtracts one number from another. + + Args: + a (float): The number to subtract from. + b (float): The number to subtract. + + Returns: + float: The result of a minus b. + """ + return a - b + +@tool() +def multiply_tool(a: float, b: float) -> float: + """Multiplies two numbers. + + Args: + a (float): The first number. + b (float): The second number. + + Returns: + float: The product of a and b. + """ + return a * b + +@tool() +def map_kdi_number(i: float) -> float: + """ + Return the mapping of the number i to its KDI value. + + Args: + i (float): The number to map. + + + Returns: + float: The KDI value of the given number. + """ + return 3.14 * i \ No newline at end of file diff --git a/tests/pre_tool/toolguard/inputs/tool_methods.py b/tests/pre_tool/toolguard/inputs/tool_methods.py new file mode 100644 index 0000000..7154857 --- /dev/null +++ b/tests/pre_tool/toolguard/inputs/tool_methods.py @@ -0,0 +1,66 @@ +class CalculatorTools: + """A collection of basic arithmetic tools.""" + + def divide_tool(self, g: float, h: float) -> float: + """Divides one number by another. + + Args: + g (float): The dividend. + h (float): The divisor (must not be zero). + + Returns: + float: The result of g divided by h. + + Raises: + ZeroDivisionError: If h is zero. + """ + return g / h + + def add_tool(self, a: float, b: float) -> float: + """Adds two numbers. + + Args: + a (float): The first number. + b (float): The second number. + + Returns: + float: The sum of a and b. + """ + return a + b + + def subtract_tool(self, a: float, b: float) -> float: + """Subtracts one number from another. + + Args: + a (float): The number to subtract from. + b (float): The number to subtract. + + Returns: + float: The result of a minus b. + """ + return a - b + + def multiply_tool(self, a: float, b: float) -> float: + """Multiplies two numbers. + + Args: + a (float): The first number. + b (float): The second number. + + Returns: + float: The product of a and b. + """ + return a * b + + def map_kdi_number(self, i: float) -> float: + """ + Return the mapping of the number i to its KDI value. + + Args: + i (float): The number to map. + + + Returns: + float: The KDI value of the given number. + """ + return 3.14 * i \ No newline at end of file diff --git a/tests/pre_tool/toolguard/test_tool_guard_calculator_policy.py b/tests/pre_tool/toolguard/test_tool_guard_calculator_policy.py deleted file mode 100644 index 0301402..0000000 --- a/tests/pre_tool/toolguard/test_tool_guard_calculator_policy.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import asyncio -import dotenv -import pytest - -from altk.pre_tool.toolguard.pre_tool_guard import PreToolGuardComponent -from altk.core.llm.base import get_llm -from altk.pre_tool.toolguard.core import ( - ToolGuardBuildInput, - ToolGuardBuildInputMetaData, - ToolGuardRunInput, - ToolGuardRunInputMetaData, -) - -import tempfile -import shutil - -dotenv.load_dotenv() - - -def divide_tool(g: float, h: float) -> float: - """ - Divide one number by another. - - Parameters - ---------- - g : float - The dividend. - h : float - The divisor (must not be zero). - - Returns - ------- - float - The result of a divided by b. 
- """ - return g / h - - -WATSONX_CREDS_AVAILABLE = all([os.getenv("WX_API_KEY"), os.getenv("WX_PROJECT_ID")]) - - -@pytest.mark.asyncio -@pytest.mark.skipif(not WATSONX_CREDS_AVAILABLE, reason="WatsonX credentials not set") -async def test_tool_guard_calculator_policy(monkeypatch): - # sets the genpy env variables necessary for toolguard - monkeypatch.setenv("TOOLGUARD_GENPY_BACKEND_NAME", "litellm") - monkeypatch.setenv( - "TOOLGUARD_GENPY_MODEL_ID", "watsonx/mistralai/mistral-medium-2505" - ) - - work_dir = tempfile.mkdtemp() - tools = [divide_tool] - policy_text = "The calculator must not allow division by zero." - - # OPENAILiteLLMClientOutputVal = get_llm("watsonx.output_val") - # validating_llm_client = OPENAILiteLLMClientOutputVal( - # model="gpt-4o-2024-08-06" - # ) - WatsonXClientOutputVal = get_llm("watsonx.output_val") - validating_llm_client = WatsonXClientOutputVal( - model_name="mistralai/mistral-medium-2505", - api_key=os.getenv("WX_API_KEY"), - project_id=os.getenv("WX_PROJECT_ID"), - url=os.getenv("WX_URL", "https://us-south.ml.cloud.ibm.com"), - ) - - tool_guard_altk = PreToolGuardComponent( - tools=tools, workdir=work_dir, app_name="calculator" - ) - build_input = ToolGuardBuildInput( - metadata=ToolGuardBuildInputMetaData( - policy_text=policy_text, - short1=True, - validating_llm_client=validating_llm_client, - ) - ) - - await tool_guard_altk._build(build_input) - - test_options = [ - ( - "Can you please calculate how much is 3/4?", - "divide_tool", - {"g": 3, "h": 4}, - True, - ), - ( - "Can you please calculate how much is 5/0?", - "divide_tool", - {"g": 5, "h": 0}, - False, - ), - ] - for user_query, tool_name, tool_params, expected in test_options: # noqa: B007 - run_input = ToolGuardRunInput( - metadata=ToolGuardRunInputMetaData( - tool_name=tool_name, - tool_parms=tool_params, - ) - ) - - run_output = tool_guard_altk._run(run_input) - print(run_output) - passed = not run_output.output.error_message - assert passed == expected - - shutil.rmtree(work_dir) diff --git a/tests/pre_tool/toolguard/test_toolguard_code.py b/tests/pre_tool/toolguard/test_toolguard_code.py new file mode 100644 index 0000000..40e08f7 --- /dev/null +++ b/tests/pre_tool/toolguard/test_toolguard_code.py @@ -0,0 +1,224 @@ +""" +End-to-end test for ToolGuard code generation & runtime evaluation +using a set of arithmetic tools. + +This test: +1. Loads tool policies +2. Generates guard code +3. Runs the guards in runtime mode +4. 
Verifies correct functionality and violations +""" + +from datetime import datetime +import os +from pathlib import Path +from typing import Dict, cast + +import dotenv +import pytest + +from altk.core.llm.base import BaseLLMClient +from altk.pre_tool.toolguard import ( + ToolGuardCodeComponent, + ToolGuardCodeBuildInput, +) +from toolguard.data_types import ( + load_tool_policy, +) +from toolguard.runtime import ( + ToolFunctionsInvoker, + ToolGuardsCodeGenerationResult, +) +from altk.pre_tool.toolguard.toolguard_code_component import ( + ToolGuardCodeComponentConfig, + ToolGuardCodeRunInput, + ToolGuardCodeRunOutput, + ViolationLevel, +) +from altk.core.toolkit import AgentPhase + +# The calculator tools under test +from .inputs.tool_functions import ( + divide_tool, + add_tool, + multiply_tool, + subtract_tool, + map_kdi_number, +) + +# Load environment variables +dotenv.load_dotenv() + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def work_dir(): + """Creates a temporary folder for test output; cleanup is commented out so the output can be inspected.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dir_path = str(Path(__file__).parent / "outputs" / f"work_{timestamp}") + print("Temporary work dir created:", dir_path) + + yield dir_path + + # shutil.rmtree(dir_path) + # print("Temporary work dir removed:", dir_path) + +def get_llm() -> BaseLLMClient: + from altk.core.llm.providers.ibm_watsonx_ai.ibm_watsonx_ai import WatsonxLLMClient + return WatsonxLLMClient( + model_name="meta-llama/llama-4-maverick-17b-128e-instruct-fp8", + api_key=os.getenv("WATSONX_API_KEY"), + project_id=os.getenv("WATSONX_PROJECT_ID"), + url=os.getenv("WATSONX_URL"), + ) + + # from altk.core.llm.providers.openai.openai import AsyncAzureOpenAIClient + # return AsyncAzureOpenAIClient( + # model="gpt-4o-2024-08-06", + # api_key=os.getenv("AZURE_OPENAI_API_KEY"), + # azure_endpoint=os.getenv("AZURE_API_BASE"), + # api_version="2024-08-01-preview" + # ) + + # from altk.core.llm.providers.litellm.litellm import LiteLLMClient + # return LiteLLMClient( + # model_name=os.getenv("TOOLGUARD_GENPY_MODEL_ID"), + # api_key=os.getenv("TOOLGUARD_GENPY_MODEL_API_KEY"), + # base_url=os.getenv("TOOLGUARD_GENPY_MODEL_BASE_URL"), + # ) + + # from altk.core.llm.providers.openai.openai import AsyncOpenAIClient + # return AsyncOpenAIClient( + # model=os.getenv("TOOLGUARD_GENPY_MODEL_ID"), + # api_key=os.getenv("TOOLGUARD_GENPY_MODEL_API_KEY"), + # url=os.getenv("TOOLGUARD_GENPY_MODEL_BASE_URL"), + # ) + + # from altk.core.llm.providers.openai.openai import AsyncOpenAIClientOutputVal + # return AsyncOpenAIClientOutputVal( + # model=os.getenv("TOOLGUARD_GENPY_MODEL_ID"), + # api_key=os.getenv("TOOLGUARD_GENPY_MODEL_API_KEY"), + # url=os.getenv("TOOLGUARD_GENPY_MODEL_BASE_URL"), + # ) + +# --------------------------------------------------------------------------- +# Test: ToolGuard verification for the calculator tool set +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tool_guard_calculator_policy(work_dir: str): + # Tools to be guarded + funcs = [divide_tool, add_tool, multiply_tool, subtract_tool, map_kdi_number] + + # Build ToolGuard component + toolguard_code = ToolGuardCodeComponent( + ToolGuardCodeComponentConfig(llm_client=get_llm()) + ) + + # Load policy JSON files from /step1 + policy_dir = Path(__file__).parent / "inputs" / "step1" + specs 
= [ + load_tool_policy(str(policy_dir / f"{tool.__name__}.json"), tool.__name__) + for tool in funcs + ] + + # Prepare build input for guard code generation + build_input = ToolGuardCodeBuildInput( + tools=funcs, + out_dir=work_dir, + toolguard_specs=specs, + ) + + # ToolGuard code generation + build_output = cast(ToolGuardsCodeGenerationResult, + await toolguard_code.aprocess(build_input, AgentPhase.BUILDTIME)) + # output = load_toolguard_code_result(work_dir) + + # Expected guarded tools + expected_tools = ["multiply_tool", "divide_tool", "add_tool"] + + # Basic structure assertions + assert build_output.out_dir + assert build_output.domain + assert len(build_output.tools) == len(expected_tools) + + # Validate guard components for each tool + for tool_name in expected_tools: + result = build_output.tools[tool_name] + + assert len(result.tool.policy_items) == 1 + assert result.guard_fn_name + assert result.guard_file + assert len(result.item_guard_files) == 1 + assert result.item_guard_files[0].content # Generated guard code + assert len(result.test_files) == 1 + assert result.test_files[0].content + + # ----------------------------------------------------------------------- + # Runtime Testing + # ----------------------------------------------------------------------- + + tool_invoker = ToolFunctionsInvoker(funcs) + + def call(tool_name: str, args: Dict) -> ToolGuardCodeRunOutput: + """Executes a tool through its guard code.""" + return cast( + ToolGuardCodeRunOutput, + toolguard_code.process( + ToolGuardCodeRunInput( + generated_guard_dir=build_output.out_dir, + tool_name=tool_name, + tool_args=args, + tool_invoker=tool_invoker, + ), + AgentPhase.RUNTIME, + ), + ) + + def assert_complies(tool_name: str, args: Dict): + """Asserts that no violation occurs.""" + assert call(tool_name, args).violation is None + + def assert_violates(tool_name: str, args: Dict): + """Asserts that a violation occurs with level ERROR.""" + res = call(tool_name, args) + assert res.violation + assert res.violation.violation_level == ViolationLevel.ERROR + assert res.violation.user_message + + # Valid input cases ----------------------------------------------------- + assert_complies("divide_tool", {"g": 5, "h": 4}) + assert_complies("add_tool", {"a": 5, "b": 4}) + assert_complies("subtract_tool", {"a": 5, "b": 4}) + assert_complies("multiply_tool", {"a": 5, "b": 4}) + assert_complies("map_kdi_number", {"i": 5}) + + # Violation cases ------------------------------------------------------- + assert_violates("divide_tool", {"g": 5, "h": 0}) + assert_violates("add_tool", {"a": 5, "b": 73}) + assert_violates("add_tool", {"a": 73, "b": 5}) + + # Violations for multiply_tool based on custom rules + assert_violates("multiply_tool", {"a": 2, "b": 73}) + assert_violates("multiply_tool", {"a": 22, "b": 2}) + + +# --------------------------------------------------------------------------- +# Optional: Main entry point for directly running the test without pytest +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import asyncio + + async def main(): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + work_dir = str(Path(__file__).parent / "outputs" / f"work_{timestamp}") + print("[main] work dir created:", work_dir) + + # Call the async test function directly + await test_tool_guard_calculator_policy(work_dir) + print("[main] Test completed successfully.") + + asyncio.run(main()) \ No newline at end of file diff --git a/tests/pre_tool/toolguard/test_toolguard_specs.py 
b/tests/pre_tool/toolguard/test_toolguard_specs.py new file mode 100644 index 0000000..e4f1b59 --- /dev/null +++ b/tests/pre_tool/toolguard/test_toolguard_specs.py @@ -0,0 +1,143 @@ +import asyncio +import os +import shutil +from datetime import datetime +from pathlib import Path +from typing import cast + +import dotenv +import pytest + +from altk.pre_tool.toolguard.toolguard_spec_component import ( + ToolGuardSpecBuildInput, + ToolGuardSpecComponent, + ToolGuardSpecComponentConfig, + ToolGuardSpecs, +) +from altk.core.toolkit import AgentPhase +from altk.core.llm.base import get_llm + +from .inputs.tool_functions import ( + divide_tool, + add_tool, + subtract_tool, + map_kdi_number, + multiply_tool, +) + +dotenv.load_dotenv() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def out_dir(): + """ + Create a timestamped directory for test output, then delete it after the test. + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dir_path = str(Path(__file__).parent / "outputs" / f"work_{timestamp}") + + print("Temporary work dir created:", dir_path) + yield dir_path + + shutil.rmtree(dir_path) + print("Temporary work dir removed:", dir_path) + + +# --------------------------------------------------------------------------- +# Main Test +# --------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tool_guard_calculator_policy(out_dir: str): + funcs = [ + divide_tool, + add_tool, + multiply_tool, + subtract_tool, + map_kdi_number, + ] + + policy_text = """ + The calculator must not allow division by zero. + The calculator must not allow multiplication if any of the operands + correspond to a number whose KDI value equals 6.28. 
+ """ + + # Example alternative LLM: + # LLMClient = get_llm("litellm.output_val") + # llm_client = LLMClient( + # model_name="gpt-4o-2024-08-06", + # custom_llm_provider="azure", + # ) + + LLMClient = get_llm("watsonx.output_val") + llm_client = LLMClient( + model_name="mistralai/mistral-medium-2505", + api_key=os.getenv("WX_API_KEY"), + project_id=os.getenv("WX_PROJECT_ID"), + url=os.getenv("WX_URL", "https://us-south.ml.cloud.ibm.com"), + ) + + toolguard_spec = ToolGuardSpecComponent( + ToolGuardSpecComponentConfig(llm_client=llm_client) + ) + + input_data = ToolGuardSpecBuildInput( + policy_text=policy_text, + tools=funcs, + out_dir=out_dir, + ) + + specs = cast( + ToolGuardSpecs, + await toolguard_spec.aprocess( + data=input_data, + phase=AgentPhase.BUILDTIME, + ), + ) + + # Validate number of results + assert len(specs) == len(funcs) + specs_by_name = {spec.tool_name: spec for spec in specs} + + # Tools that should have policy items + expected_tools = ["multiply_tool", "divide_tool"] + + # Tools that should produce no policy items + empty_tools = ["add_tool", "subtract_tool", "map_kdi_number"] + + # Validate expected tools have populated spec items + for tool_name in expected_tools: + spec = specs_by_name[tool_name] + + assert len(spec.policy_items) == 1 + item = spec.policy_items[0] + + assert item.name + assert item.description + assert len(item.references) > 0 + assert item.compliance_examples and len(item.compliance_examples) > 1 + assert item.violation_examples and len(item.violation_examples) > 1 + + # Validate tools that should be empty + for tool_name in empty_tools: + spec = specs_by_name[tool_name] + assert len(spec.policy_items) == 0 + + +# --------------------------------------------------------------------------- +# Optional: Run test directly (without pytest) +# --------------------------------------------------------------------------- +if __name__ == "__main__": + + async def main(): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + work_dir = str(Path(__file__).parent / "outputs" / f"work_{timestamp}") + + print("[main] work dir created:", work_dir) + await test_tool_guard_calculator_policy(work_dir) + print("[main] Test completed successfully.") + + asyncio.run(main())