diff --git a/.github/PULL_REQUEST_TEMPLATE/component.md b/.github/PULL_REQUEST_TEMPLATE/component.md
new file mode 100644
index 00000000..9c5cd619
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/component.md
@@ -0,0 +1,27 @@
+# Component PR
+
+Use this template when adding or modifying components in `mellea/stdlib/components/`.
+
+## Description
+- [ ] Link to Issue:
+
+
+
+## Implementation Checklist
+
+### Protocol Compliance
+- [ ] `parts()` returns a list of constituent parts (Components or CBlocks)
+- [ ] `format_for_llm()` returns a TemplateRepresentation or string
+- [ ] `_parse(computed: ModelOutputThunk)` parses model output correctly into the specified Component return type
+
+### Content Blocks
+- [ ] CBlock used appropriately for text content
+- [ ] ImageBlock used for image content (if applicable)
+
+### Integration
+- [ ] Component exported in `mellea/stdlib/components/__init__.py` or, if you are adding a library of components, from your sub-module
+
+### Testing
+- [ ] Tests added to `tests/components/`
+- [ ] New code has 100% coverage
+- [ ] Ensure existing tests and GitHub automation pass (a maintainer will kick off the GitHub automation when the rest of the PR is populated)
diff --git a/.github/PULL_REQUEST_TEMPLATE/misc.md b/.github/PULL_REQUEST_TEMPLATE/misc.md
new file mode 100644
index 00000000..663a681e
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/misc.md
@@ -0,0 +1,18 @@
+# Misc PR
+
+## Type of PR
+
+- [ ] Bug Fix
+- [ ] New Feature
+- [ ] Documentation
+- [ ] Other
+
+## Description
+- [ ] Link to Issue:
+
+
+
+### Testing
+- [ ] Tests added to the respective file if code was changed
+- [ ] New code has 100% coverage if code was added
+- [ ] Ensure existing tests and GitHub automation pass (a maintainer will kick off the GitHub automation when the rest of the PR is populated)
diff --git a/.github/PULL_REQUEST_TEMPLATE/requirement.md b/.github/PULL_REQUEST_TEMPLATE/requirement.md
new file mode 100644
index 00000000..91b03e41
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/requirement.md
@@ -0,0 +1,30 @@
+# Requirement PR
+
+Use this template when adding or modifying requirements in `mellea/stdlib/requirements/`.
+
+## Description
+- [ ] Link to Issue:
+
+
+
+## Implementation Checklist
+
+### Base Class
+- [ ] Extends appropriate base class:
+  - `Requirement` - standard requirement
+  - `ALoraRequirement` - uses a specialized Intrinsic/Adapter for generation-based validation
+
+### Validation Logic
+- [ ] `validation_fn` defined (if using Python-based validation)
+  - [ ] reusable functionality within the `validation_fn` should be separated out into `mellea/stdlib/tools/`
+- [ ] `validate` returns a `ValidationResult` with
+  - [ ] a `thunk` and `context` if using a backend to generate
+  - [ ] a specific `reason` and `score` when possible
+
+### Integration
+- [ ] Requirement exported in `mellea/stdlib/requirements/__init__.py` or, if you are adding a library of requirements, from your sub-module
+
+### Testing
+- [ ] Tests added to `tests/requirements/`
+- [ ] New code has 100% coverage
+- [ ] Ensure existing tests and GitHub automation pass (a maintainer will kick off the GitHub automation when the rest of the PR is populated)
diff --git a/.github/PULL_REQUEST_TEMPLATE/sampling.md b/.github/PULL_REQUEST_TEMPLATE/sampling.md
new file mode 100644
index 00000000..5ce25029
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/sampling.md
@@ -0,0 +1,28 @@
+# Sampling Strategy PR
+
+Use this template when adding or modifying sampling strategies in `mellea/stdlib/sampling/`.
+
+## Description
+- [ ] Link to Issue:
+
+
+
+## Implementation Checklist
+
+### Base Class
+- [ ] Extends appropriate base class:
+  - `BaseSamplingStrategy` if your changes are mostly modifying the `repair` and/or `select_from_failure` functions
+  - `SamplingStrategy` if your changes involve a new `sample` method
+  - Other defined sampling strategies if your implementation is similar to existing implementations
+
+### Return Value
+- [ ] Returns a properly typed `SamplingResult`. Specifically, this means:
+  - `ModelOutputThunk`s in `sample_generations` are properly typed from the Component and the `parsed_repr` is the expected type.
+
+### Integration
+- [ ] Strategy exported in `mellea/stdlib/sampling/__init__.py`
+
+### Testing
+- [ ] Tests added to `tests/sampling/`
+- [ ] New code has 100% coverage
+- [ ] Ensure existing tests and GitHub automation pass (a maintainer will kick off the GitHub automation when the rest of the PR is populated)
diff --git a/.github/PULL_REQUEST_TEMPLATE/tool.md b/.github/PULL_REQUEST_TEMPLATE/tool.md
new file mode 100644
index 00000000..bf191c17
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/tool.md
@@ -0,0 +1,22 @@
+# Tool PR
+
+Use this template when adding or modifying tools in `mellea/stdlib/tools/`.
+
+## Description
+- [ ] Link to Issue:
+
+
+
+## Implementation Checklist
+
+### Protocol Compliance
+- [ ] Ensure compatibility with existing backends and providers
+  - For most tools being added as functions, this means that calling `convert_function_to_tool` works
+
+### Integration
+- [ ] Tool exported in `mellea/stdlib/tools/__init__.py` or, if you are adding a library of tools, from your sub-module
+
+### Testing
+- [ ] Tests added to `tests/stdlib/tools/`
+- [ ] New code has 100% coverage
+- [ ] Ensure existing tests and GitHub automation pass (a maintainer will kick off the GitHub automation when the rest of the PR is populated)
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 00000000..9171ee5d
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,13 @@
+# Pull Request
+
+NOTE: Please ensure you have an issue that has been acknowledged by a core contributor, who has routed you to open a pull request against this repository. Otherwise, please open an issue before continuing with this pull request.
+
+Only modify this text by checking one of the boxes below. This text will be overwritten with a specific pull request template based on your choice.
+
+Type of Pull Request:
+
+- [ ] Component
+- [ ] Requirement
+- [ ] Sampling Strategy
+- [ ] Tool
+- [ ] Misc: Bug Fix, New Feature, Documentation Update, Other
diff --git a/.github/workflows/pr-update.yml b/.github/workflows/pr-update.yml
new file mode 100644
index 00000000..33505a0f
--- /dev/null
+++ b/.github/workflows/pr-update.yml
@@ -0,0 +1,73 @@
+name: PR Bot
+
+on:
+  pull_request_target:
+    types: [opened, edited]
+
+jobs:
+  update-pr-body:
+    runs-on: ubuntu-latest
+    if: ${{ !contains(github.event.pull_request.body, 'mellea-pr-edited-marker') }}
+    permissions:
+      pull-requests: write
+      contents: read
+    steps:
+      - name: Checkout code # Checks out the base branch, not the PR branch.
+        uses: actions/checkout@v4
+
+      - name: Detect PR type from checkboxes
+        id: detect-type
+        env:
+          PR_BODY: ${{ github.event.pull_request.body }}
+        run: |
+          PR_TYPE=""
+
+          # Check for checked boxes (supports [x] and [X])
+          if echo "$PR_BODY" | grep -qi '\[x\] Component'; then
+            PR_TYPE="component"
+          elif echo "$PR_BODY" | grep -qi '\[x\] Requirement'; then
+            PR_TYPE="requirement"
+          elif echo "$PR_BODY" | grep -qi '\[x\] Sampling Strategy'; then
+            PR_TYPE="sampling"
+          elif echo "$PR_BODY" | grep -qi '\[x\] Tool'; then
+            PR_TYPE="tool"
+          elif echo "$PR_BODY" | grep -qi '\[x\] Misc'; then
+            PR_TYPE="misc"
+          fi
+
+          if [ -z "$PR_TYPE" ]; then
+            echo "::error::No PR type selected. Please check one of the boxes from the original PR template."
+            exit 1
+          fi
+
+          echo "pr_type=$PR_TYPE" >> "$GITHUB_OUTPUT"
+          echo "Detected PR type: $PR_TYPE"
+
+      - name: Update PR body with checklist
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          PR_BODY: ${{ github.event.pull_request.body }}
+          PR_TYPE: ${{ steps.detect-type.outputs.pr_type }}
+        run: |
+          TEMPLATE_FILE=".github/PULL_REQUEST_TEMPLATE/${PR_TYPE}.md"
+
+          if [ -f "$TEMPLATE_FILE" ]; then
+            MARKER="<!-- mellea-pr-edited-marker -->"
+            TEMPLATE_CONTENT=$(cat "$TEMPLATE_FILE")
+
+            NEW_BODY="${MARKER}
+          ${TEMPLATE_CONTENT}"
+
+            gh pr edit ${{ github.event.pull_request.number }} --body "$NEW_BODY"
+            echo "Updated PR body with ${PR_TYPE} checklist"
+          else
+            echo "::error::Template file not found: $TEMPLATE_FILE"
+            echo "Something has gone wrong. Contact a maintainer."
+ exit 1 + fi + + - name: Comment on PR + uses: mshick/add-pr-comment@b8f338c590a895d50bcbfa6c5859251edc8952fc + with: + message: | + The PR description has been updated. Please fill out the template for your PR to be reviewed. diff --git a/README.md b/README.md index 684ec39c..76fe912f 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ the output is checked against the constraints using (in this case) LLM-as-a-judg ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/instruct_validate_repair/101_email_with_validate.py from mellea import MelleaSession -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption from mellea.backends.ollama import OllamaModelBackend from mellea.backends import model_ids from mellea.stdlib.sampling import RejectionSamplingStrategy diff --git a/cli/decompose/pipeline.py b/cli/decompose/pipeline.py index 60ed9cf1..1c6dda1a 100644 --- a/cli/decompose/pipeline.py +++ b/cli/decompose/pipeline.py @@ -7,7 +7,7 @@ from mellea import MelleaSession from mellea.backends.ollama import OllamaModelBackend from mellea.backends.openai import OpenAIBackend -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption from .prompt_modules import ( constraint_extractor, diff --git a/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py b/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py index 47cf8bf5..43558cce 100644 --- a/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py +++ b/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py @@ -3,8 +3,8 @@ from typing import Any, TypeVar, final from mellea import MelleaSession -from mellea.backends.types import ModelOption -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString from ._exceptions import BackendGenerationError, TagExtractionError diff --git a/cli/decompose/prompt_modules/general_instructions/_general_instructions.py b/cli/decompose/prompt_modules/general_instructions/_general_instructions.py index c640fb15..26b51c43 100644 --- a/cli/decompose/prompt_modules/general_instructions/_general_instructions.py +++ b/cli/decompose/prompt_modules/general_instructions/_general_instructions.py @@ -3,8 +3,8 @@ from typing import Any, TypeVar, final from mellea import MelleaSession -from mellea.backends.types import ModelOption -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString from ._exceptions import BackendGenerationError, TagExtractionError diff --git a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py index 8f89eebd..21aece9d 100644 --- a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py +++ b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py @@ -5,8 +5,8 @@ from typing_extensions import Unpack from mellea import MelleaSession -from mellea.backends.types import ModelOption -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString from ._exceptions import 
BackendGenerationError, TagExtractionError diff --git a/cli/decompose/prompt_modules/subtask_list/_subtask_list.py b/cli/decompose/prompt_modules/subtask_list/_subtask_list.py index a842325a..bf5eed38 100644 --- a/cli/decompose/prompt_modules/subtask_list/_subtask_list.py +++ b/cli/decompose/prompt_modules/subtask_list/_subtask_list.py @@ -3,8 +3,8 @@ from typing import Any, TypeVar, final from mellea import MelleaSession -from mellea.backends.types import ModelOption -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString from ._exceptions import ( @@ -15,7 +15,7 @@ from ._prompt import get_system_prompt, get_user_prompt from ._types import SubtaskItem -# from mellea.stdlib.requirement import Requirement +# from mellea.stdlib.requirements import requirement T = TypeVar("T") diff --git a/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py b/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py index 780fe26b..45982b7d 100644 --- a/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py +++ b/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py @@ -5,8 +5,8 @@ from typing_extensions import Unpack from mellea import MelleaSession -from mellea.backends.types import ModelOption -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString from ._exceptions import BackendGenerationError, TagExtractionError diff --git a/cli/decompose/prompt_modules/validation_decision/_validation_decision.py b/cli/decompose/prompt_modules/validation_decision/_validation_decision.py index aacd29ca..3d7d39fa 100644 --- a/cli/decompose/prompt_modules/validation_decision/_validation_decision.py +++ b/cli/decompose/prompt_modules/validation_decision/_validation_decision.py @@ -3,8 +3,8 @@ from typing import Any, Final, Literal, TypeVar, final from mellea import MelleaSession -from mellea.backends.types import ModelOption -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString from ._exceptions import BackendGenerationError, TagExtractionError diff --git a/cli/eval/runner.py b/cli/eval/runner.py index 199581f1..91dda41b 100644 --- a/cli/eval/runner.py +++ b/cli/eval/runner.py @@ -4,9 +4,9 @@ from typing import List import mellea -from mellea.stdlib.base import ModelOutputThunk +from mellea.core import ModelOutputThunk from mellea.stdlib.test_based_eval import TestBasedEval -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption from rich.console import Console from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn diff --git a/docs/dev/requirement_aLoRA_rerouting.md b/docs/dev/requirement_aLoRA_rerouting.md index 011073bd..74794f8a 100644 --- a/docs/dev/requirement_aLoRA_rerouting.md +++ b/docs/dev/requirement_aLoRA_rerouting.md @@ -33,7 +33,8 @@ Suppose that the user creates a backend and then adds a generic constraint check ```python from mellea import start_session -from mellea.stdlib.requirement import Requirement +from mellea.core import Requirement +from mellea.backends.adapters import GraniteCommonAdapter m = start_session( 
"huggingface.LocalHFBackend:ibm-granite/granite-3.2-8b-instruct") diff --git a/docs/examples/agents/react.py b/docs/examples/agents/react.py index 27099592..117f1440 100644 --- a/docs/examples/agents/react.py +++ b/docs/examples/agents/react.py @@ -8,14 +8,9 @@ from jinja2 import Template import mellea -import mellea.backends -import mellea.backends.types -import mellea.stdlib -import mellea.stdlib.base -import mellea.stdlib.chat -from mellea.backends import model_ids -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ChatContext +import mellea.stdlib.components.chat +from mellea.core import FancyLogger +from mellea.stdlib.context import ChatContext FancyLogger.get_logger().setLevel("ERROR") @@ -120,8 +115,8 @@ def react( # Add the system prompt and the goal to the chat history. m.ctx = m.ctx.add( - mellea.stdlib.chat.Message(role="system", content=_sys_prompt) - ).add(mellea.stdlib.chat.Message(role="user", content=f"{goal}")) + mellea.stdlib.components.chat.Message(role="system", content=_sys_prompt) + ).add(mellea.stdlib.components.chat.Message(role="user", content=f"{goal}")) # The main ReACT loop as a dynamic program: # ( ?(not done) ; @@ -162,7 +157,9 @@ def react( print("### Observation") tool_output = react_toolbox.call_tool(selected_tool, act_args.content) - m.ctx = m.ctx.add(mellea.stdlib.chat.Message(role="tool", content=tool_output)) + m.ctx = m.ctx.add( + mellea.stdlib.components.chat.Message(role="tool", content=tool_output) + ) print(tool_output) print("### Done Check") diff --git a/docs/examples/agents/react_instruct.py b/docs/examples/agents/react_instruct.py index 5102c650..b72adbc6 100644 --- a/docs/examples/agents/react_instruct.py +++ b/docs/examples/agents/react_instruct.py @@ -8,10 +8,8 @@ from jinja2 import Template import mellea -import mellea.stdlib -import mellea.stdlib.base -import mellea.stdlib.chat -from mellea.stdlib.base import ChatContext +import mellea.stdlib.components.chat +from mellea.stdlib.context import ChatContext react_system_template: Template = Template( """Answer the user's question as best you can. @@ -114,8 +112,8 @@ def react( # Add the system prompt and the goal to the chat history. 
m.ctx = m.ctx.add( - mellea.stdlib.chat.Message(role="system", content=_sys_prompt) - ).add(mellea.stdlib.chat.Message(role="user", content=f"{goal}")) + mellea.stdlib.components.chat.Message(role="system", content=_sys_prompt) + ).add(mellea.stdlib.components.chat.Message(role="user", content=f"{goal}")) # The main ReACT loop as a dynamic program: # ( ?(not done) ; @@ -161,7 +159,9 @@ def react( print("### Observation") tool_output = react_toolbox.call_tool(selected_tool, act_args_val) - m.ctx = m.ctx.add(mellea.stdlib.chat.Message(role="tool", content=tool_output)) + m.ctx = m.ctx.add( + mellea.stdlib.components.chat.Message(role="tool", content=tool_output) + ) print(tool_output) print("### Done Check") diff --git a/docs/examples/best_of_n/prm.py b/docs/examples/best_of_n/prm.py deleted file mode 100644 index 4da6d156..00000000 --- a/docs/examples/best_of_n/prm.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Example of Using Best of N with PRMs.""" - -from docs.examples.helper import w -from mellea import start_session -from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.backends.process_reward_models.huggingface.prms import ( - HFGenerativePRM, - HFRegressionPRM, -) -from mellea.backends.types import ModelOption -from mellea.stdlib.rewards.prm_scorer import PRMScorer -from mellea.stdlib.sampling.best_of_n import BestofNSamplingStrategy - -# create a session for the generator using Granite 3.3 8B on Huggingface and a simple context [see below] -m = start_session( - backend_name="hf", - model_id=IBM_GRANITE_3_3_8B, - model_options={ModelOption.MAX_NEW_TOKENS: 512}, -) - -# initialize the PRM model -prm_model = HFGenerativePRM( - model_name_or_path="ibm-granite/granite-3.3-8b-lora-math-prm", - score_token="Y", - generation_prompt="Is this response correct so far (Y/N)?", - step_separator="\n\n", -) - -# # can also initialize a Regression PRM model -# prm_model = HFRegressionPRM( -# model_name_or_path = "granite-3.3-8b-math-prm-regression", -# score_token= "", -# step_separator= "\n\n") - -# create PRM scorer object -prm = PRMScorer(prm_model=prm_model, preference_ordering="max") - -# Do Best of N sampling with the PRM scorer and an additional requirement -BoN_prm = m.instruct( - "Sarah has 12 apples. She gives 5 of them to her friend. How many apples does Sarah have left?", - strategy=BestofNSamplingStrategy(loop_budget=3), - model_options={"temperature": 0.9, "do_sample": True}, - requirements=["provide final answer like 'Final Answer:'", prm], -) - -# print result -print(f"***** BoN ****\n{w(BoN_prm)}\n*******") diff --git a/docs/examples/context/contexts_with_sampling.py b/docs/examples/context/contexts_with_sampling.py index d760ca2a..1f71397b 100644 --- a/docs/examples/context/contexts_with_sampling.py +++ b/docs/examples/context/contexts_with_sampling.py @@ -1,5 +1,4 @@ -from mellea.backends.types import ModelOption -from mellea.stdlib.sampling.base import RejectionSamplingStrategy +from mellea.stdlib.sampling import RejectionSamplingStrategy from mellea.stdlib.session import start_session # You can retrieve context information when using SamplingStrategies @@ -38,7 +37,7 @@ # We can see the context that created this output. gen_ctx = res.sample_contexts[index] -print(f"Previous step in generating this result was: {gen_ctx.previous_node.node_data}") +print(f"Previous step in generating this result was: {gen_ctx.previous_node.node_data}") # type: ignore print() # We can also see what the validation context looked like. 
@@ -48,4 +47,4 @@ ) val_ctx = val_result.context -print(f"Output of the validation for this requirement: {val_ctx.node_data}") +print(f"Output of the validation for this requirement: {val_ctx.node_data}") # type: ignore diff --git a/docs/examples/generative_slots/generate_with_context.py b/docs/examples/generative_slots/generate_with_context.py index 98050523..8d4d3c87 100644 --- a/docs/examples/generative_slots/generate_with_context.py +++ b/docs/examples/generative_slots/generate_with_context.py @@ -1,6 +1,7 @@ from mellea import generative, start_session -from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, ChatContext +from mellea.backends import ModelOption +from mellea.core import CBlock +from mellea.stdlib.context import ChatContext # Generative slots can be used with sessions that have context. # By utilizing context, you can change the results of several diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py index 7304bc3c..6f5a610a 100644 --- a/docs/examples/generative_slots/generative_slots_with_requirements.py +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -1,8 +1,9 @@ from typing import Literal from mellea import generative, start_session -from mellea.stdlib.genslot import PreconditionException -from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.stdlib.components.genslot import PreconditionException +from mellea.stdlib.requirements import simple_validate +from mellea.core import Requirement from mellea.stdlib.sampling.base import RejectionSamplingStrategy diff --git a/docs/examples/helper/helpers.py b/docs/examples/helper/helpers.py index f6b412f4..ad8a5a3d 100644 --- a/docs/examples/helper/helpers.py +++ b/docs/examples/helper/helpers.py @@ -1,7 +1,7 @@ from textwrap import fill from typing import Any -from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.core import Requirement, ValidationResult # Just for printing stuff nicely... diff --git a/docs/examples/image_text_models/vision_litellm_backend.py b/docs/examples/image_text_models/vision_litellm_backend.py index 69741dc9..03180fb2 100644 --- a/docs/examples/image_text_models/vision_litellm_backend.py +++ b/docs/examples/image_text_models/vision_litellm_backend.py @@ -8,7 +8,7 @@ from mellea import MelleaSession, start_session from mellea.backends.litellm import LiteLLMBackend from mellea.backends.openai import OpenAIBackend -from mellea.stdlib.base import ImageBlock +from mellea.core import ImageBlock import pathlib # use LiteLLM to talk to Ollama or anthropic or..... 
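The hunks above repeat one import mapping. As a reference, a minimal sketch of the relocated names after this refactor; the model id, context, and option values are illustrative, and some imports are listed only to document their new homes:

```python
from mellea import start_session                # top-level entry point (unchanged)
from mellea.backends import ModelOption         # was mellea.backends.types
from mellea.core import CBlock, ImageBlock      # was mellea.stdlib.base
from mellea.stdlib.components import Message    # was mellea.stdlib.chat
from mellea.stdlib.context import ChatContext   # was mellea.stdlib.base

# Illustrative wiring of the relocated names in a single session.
m = start_session(
    model_id="granite3.2-vision",
    ctx=ChatContext(),
    model_options={ModelOption.MAX_NEW_TOKENS: 200},
)
response = m.chat("In one sentence, what is a content block?")
print(response.content)
```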
diff --git a/docs/examples/image_text_models/vision_ollama_chat.py b/docs/examples/image_text_models/vision_ollama_chat.py index 21f236a5..49fb1198 100644 --- a/docs/examples/image_text_models/vision_ollama_chat.py +++ b/docs/examples/image_text_models/vision_ollama_chat.py @@ -4,7 +4,7 @@ from PIL import Image from mellea import start_session -from mellea.stdlib.base import ChatContext, ImageBlock +from mellea.stdlib.context import ChatContext m = start_session(model_id="granite3.2-vision", ctx=ChatContext()) # m = start_session(model_id="llava", ctx=ChatContext()) diff --git a/docs/examples/image_text_models/vision_openai_examples.py b/docs/examples/image_text_models/vision_openai_examples.py index 250e0696..1ca58658 100644 --- a/docs/examples/image_text_models/vision_openai_examples.py +++ b/docs/examples/image_text_models/vision_openai_examples.py @@ -6,7 +6,8 @@ from mellea import MelleaSession from mellea.backends.openai import OpenAIBackend -from mellea.stdlib.base import ChatContext, ImageBlock +from mellea.stdlib.context import ChatContext +from mellea.core import ImageBlock # # using anthropic AI model ... # anth_key = os.environ.get("ANTHROPIC_API_KEY") diff --git a/docs/examples/information_extraction/advanced_with_m_instruct.py b/docs/examples/information_extraction/advanced_with_m_instruct.py index a76258a5..d2678952 100644 --- a/docs/examples/information_extraction/advanced_with_m_instruct.py +++ b/docs/examples/information_extraction/advanced_with_m_instruct.py @@ -6,8 +6,9 @@ from mellea import start_session from mellea.backends import model_ids -from mellea.stdlib.requirement import check, simple_validate -from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult +from mellea.stdlib.requirements import check, simple_validate +from mellea.stdlib.sampling import RejectionSamplingStrategy +from mellea.core import SamplingResult # ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html NYTimes_text = "CAMP DAVID, Md. — Leaders of the world's richest countries banded together on Saturday to press Germany to back more pro-growth policies to halt the deepening debt crisis in Europe, as President Obama for the first time gained widespread support for his argument that Europe, and the United States by extension, cannot afford Chancellor Angela Merkel's one-size-fits-all approach emphasizing austerity." 
diff --git a/docs/examples/instruct_validate_repair/101_email.py b/docs/examples/instruct_validate_repair/101_email.py index 4d86c485..eb0bb9a1 100644 --- a/docs/examples/instruct_validate_repair/101_email.py +++ b/docs/examples/instruct_validate_repair/101_email.py @@ -2,7 +2,7 @@ # helper function to wrap text from docs.examples.helper import w from mellea import start_session -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m: # write an email diff --git a/docs/examples/instruct_validate_repair/101_email_comparison.py b/docs/examples/instruct_validate_repair/101_email_comparison.py index ef6a2704..760cea76 100644 --- a/docs/examples/instruct_validate_repair/101_email_comparison.py +++ b/docs/examples/instruct_validate_repair/101_email_comparison.py @@ -1,6 +1,6 @@ from docs.examples.helper import w from mellea import start_session -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption # create a session using Granite 4 Micro 3B on Ollama and a simple context [see below] m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) diff --git a/docs/examples/instruct_validate_repair/101_email_with_requirements.py b/docs/examples/instruct_validate_repair/101_email_with_requirements.py index 9f885c40..5d9f21a8 100644 --- a/docs/examples/instruct_validate_repair/101_email_with_requirements.py +++ b/docs/examples/instruct_validate_repair/101_email_with_requirements.py @@ -1,6 +1,6 @@ from docs.examples.helper import w from mellea import start_session -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption # create a session using Granite 4 Micro 3B on Ollama and a simple context [see below] m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) diff --git a/docs/examples/instruct_validate_repair/101_email_with_validate.py b/docs/examples/instruct_validate_repair/101_email_with_validate.py index 24bbc5b4..a7a0e500 100644 --- a/docs/examples/instruct_validate_repair/101_email_with_validate.py +++ b/docs/examples/instruct_validate_repair/101_email_with_validate.py @@ -1,7 +1,7 @@ from docs.examples.helper import req_print, w from mellea import start_session from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption from mellea.stdlib.sampling import RejectionSamplingStrategy # create a session using Granite 4 Micro (3B) on Ollama and a simple context [see below] diff --git a/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py b/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py index 71a9d408..0ffe0d13 100644 --- a/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py +++ b/docs/examples/instruct_validate_repair/advanced_email_with_validate_function.py @@ -1,7 +1,8 @@ from docs.examples.helper import w from mellea import start_session -from mellea.backends.types import ModelOption -from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.backends import ModelOption +from mellea.stdlib.requirements import simple_validate +from mellea.core import Requirement from mellea.stdlib.sampling import RejectionSamplingStrategy # create a session using Granite 4 Micro 3B on Ollama and a simple context [see below] diff --git a/docs/examples/intrinsics/answer_relevance.py b/docs/examples/intrinsics/answer_relevance.py index 
d0e7bbd1..f945c6dd 100644 --- a/docs/examples/intrinsics/answer_relevance.py +++ b/docs/examples/intrinsics/answer_relevance.py @@ -8,9 +8,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message, Document +from mellea.stdlib.components.intrinsic import rag backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") diff --git a/docs/examples/intrinsics/answerability.py b/docs/examples/intrinsics/answerability.py index ce9ff069..6804c5d7 100644 --- a/docs/examples/intrinsics/answerability.py +++ b/docs/examples/intrinsics/answerability.py @@ -8,10 +8,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag - +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message, Document +from mellea.stdlib.components.intrinsic import rag backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext().add(Message("assistant", "Hello there, how can I help you?")) diff --git a/docs/examples/intrinsics/citations.py b/docs/examples/intrinsics/citations.py index 09fe7724..74377091 100644 --- a/docs/examples/intrinsics/citations.py +++ b/docs/examples/intrinsics/citations.py @@ -8,9 +8,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message, Document +from mellea.stdlib.components.intrinsic import rag import json diff --git a/docs/examples/intrinsics/context_relevance.py b/docs/examples/intrinsics/context_relevance.py index ff6c985d..470973e3 100644 --- a/docs/examples/intrinsics/context_relevance.py +++ b/docs/examples/intrinsics/context_relevance.py @@ -8,10 +8,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag - +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Document +from mellea.stdlib.components.intrinsic import rag backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ChatContext() diff --git a/docs/examples/intrinsics/hallucination_detection.py b/docs/examples/intrinsics/hallucination_detection.py index ed1838d7..271e76a3 100644 --- a/docs/examples/intrinsics/hallucination_detection.py +++ b/docs/examples/intrinsics/hallucination_detection.py @@ -8,9 +8,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message, Document +from mellea.stdlib.components.intrinsic import rag import json diff --git a/docs/examples/intrinsics/intrinsics.py b/docs/examples/intrinsics/intrinsics.py index 1392c551..10ba4e97 100644 --- a/docs/examples/intrinsics/intrinsics.py +++ b/docs/examples/intrinsics/intrinsics.py @@ -1,10 +1,10 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.backends.openai 
import OpenAIBackend, _ServerType from mellea.backends.adapters.adapter import AdapterType, GraniteCommonAdapter -from mellea.stdlib.base import ChatContext, ModelOutputThunk -from mellea.stdlib.chat import Message +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message import mellea.stdlib.functional as mfuncs -from mellea.stdlib.intrinsics.intrinsic import Intrinsic +from mellea.stdlib.components import Intrinsic # This is an example for how you would directly use intrinsics. See `mellea/stdlib/intrinsics/rag.py` # for helper functions. diff --git a/docs/examples/intrinsics/query_rewrite.py b/docs/examples/intrinsics/query_rewrite.py index 39aabac1..a95cadc7 100644 --- a/docs/examples/intrinsics/query_rewrite.py +++ b/docs/examples/intrinsics/query_rewrite.py @@ -8,10 +8,9 @@ """ from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag - +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message +from mellea.stdlib.components.intrinsic import rag backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro") context = ( diff --git a/docs/examples/library_interop/langchain_messages.py b/docs/examples/library_interop/langchain_messages.py index fbd65185..8d99720d 100644 --- a/docs/examples/library_interop/langchain_messages.py +++ b/docs/examples/library_interop/langchain_messages.py @@ -15,9 +15,9 @@ messages = convert_to_openai_messages(messages=messages) # Import Mellea. -from mellea.stdlib.chat import Message -from mellea.stdlib.base import ChatContext -from mellea.backends.types import ModelOption +from mellea.stdlib.components import Message +from mellea.stdlib.context import ChatContext +from mellea.backends import ModelOption from mellea import start_session # Mellea uses explicit contexts. 
Cast the OpenAI formatted messages into diff --git a/docs/examples/m_serve/best_of_n.py b/docs/examples/m_serve/best_of_n.py deleted file mode 100644 index aca43e44..00000000 --- a/docs/examples/m_serve/best_of_n.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Example to run m serve.""" - -import os - -import pydantic - -import mellea -from mellea.stdlib.base import CBlock, Component, ModelOutputThunk -from mellea.stdlib.sampling import SamplingResult - - -class RankerResponse(pydantic.BaseModel): - best_choice: int - - -session = mellea.start_session() - - -def serve( - input: str, requirements: list[str] | None = None, model_options: None | dict = None -) -> ModelOutputThunk | SamplingResult | None: - N = int(os.environ["N"] if "N" in os.environ else "3") - attempts: dict[str, str | CBlock | Component] = { - str(i): session.instruct(input) for i in range(N) - } - attempts["query"] = input - ranking_output = session.instruct( - "Choose the best response to the user's query", - grounding_context=attempts, - format=RankerResponse, - ).value - if ranking_output is not None: - choice = int(RankerResponse.model_validate_json(ranking_output).best_choice) - print(f"selected {choice}") - assert 0 < choice < len(attempts) - res = attempts[str(choice - 1)] - assert isinstance(res, ModelOutputThunk) or isinstance(res, SamplingResult) - return res - else: - return None diff --git a/docs/examples/m_serve/m_serve_example_simple.py b/docs/examples/m_serve/m_serve_example_simple.py index 79bf6e68..f1dff480 100644 --- a/docs/examples/m_serve/m_serve_example_simple.py +++ b/docs/examples/m_serve/m_serve_example_simple.py @@ -4,9 +4,10 @@ import mellea from cli.serve.models import ChatMessage -from mellea.stdlib.base import ChatContext, ModelOutputThunk -from mellea.stdlib.requirement import Requirement, simple_validate -from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult +from mellea.stdlib.context import ChatContext +from mellea.core import ModelOutputThunk, SamplingResult, Requirement +from mellea.stdlib.requirements import simple_validate +from mellea.stdlib.sampling import RejectionSamplingStrategy session = mellea.start_session(ctx=ChatContext()) diff --git a/docs/examples/m_serve/pii_serve.py b/docs/examples/m_serve/pii_serve.py index 09a40254..356867e8 100644 --- a/docs/examples/m_serve/pii_serve.py +++ b/docs/examples/m_serve/pii_serve.py @@ -5,10 +5,9 @@ from cli.serve.models import ChatMessage import mellea from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B -from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.requirement import req, simple_validate +from mellea.core import ModelOutputThunk, SamplingResult +from mellea.stdlib.requirements import req, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy -from mellea.stdlib.sampling.types import SamplingResult def has_potential_pii(text: str) -> bool: diff --git a/docs/examples/mcp/mcp_example.py b/docs/examples/mcp/mcp_example.py index 48c042b5..43d965dc 100644 --- a/docs/examples/mcp/mcp_example.py +++ b/docs/examples/mcp/mcp_example.py @@ -12,8 +12,8 @@ from mellea import MelleaSession from mellea.backends import ModelOption, model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.core import ModelOutputThunk, Requirement +from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy # 
################# diff --git a/docs/examples/melp/lazy.py b/docs/examples/melp/lazy.py index 0f212f25..3a3b9e5f 100644 --- a/docs/examples/melp/lazy.py +++ b/docs/examples/melp/lazy.py @@ -1,12 +1,10 @@ import asyncio -from mellea.stdlib.base import ( - SimpleContext, - Context, - CBlock, - ModelOutputThunk, - SimpleComponent, -) -from mellea.backends import Backend +from mellea.core import Context, CBlock, ModelOutputThunk + +from mellea.stdlib.components import SimpleComponent +from mellea.stdlib.context import SimpleContext + +from mellea.core import Backend from mellea.backends.ollama import OllamaModelBackend backend = OllamaModelBackend("granite4:latest") diff --git a/docs/examples/melp/lazy_fib.py b/docs/examples/melp/lazy_fib.py index feec18e0..e91a4a2b 100644 --- a/docs/examples/melp/lazy_fib.py +++ b/docs/examples/melp/lazy_fib.py @@ -1,13 +1,10 @@ import asyncio -from mellea.stdlib.base import ( - SimpleContext, - Context, - CBlock, - ModelOutputThunk, - SimpleComponent, -) -from mellea.stdlib.requirement import Requirement -from mellea.backends import Backend +from mellea.core import Context, CBlock, ModelOutputThunk + +from mellea.stdlib.components import SimpleComponent +from mellea.stdlib.context import SimpleContext + +from mellea.core import Backend from mellea.backends.ollama import OllamaModelBackend from typing import Tuple diff --git a/docs/examples/melp/lazy_fib_sample.py b/docs/examples/melp/lazy_fib_sample.py index 0bec2907..0224f4a3 100644 --- a/docs/examples/melp/lazy_fib_sample.py +++ b/docs/examples/melp/lazy_fib_sample.py @@ -1,13 +1,10 @@ import asyncio -from mellea.stdlib.base import ( - SimpleContext, - Context, - CBlock, - ModelOutputThunk, - SimpleComponent, -) -from mellea.stdlib.requirement import Requirement -from mellea.backends import Backend +from mellea.core import Context, CBlock, ModelOutputThunk + +from mellea.stdlib.components import SimpleComponent +from mellea.stdlib.context import SimpleContext + +from mellea.core import Backend from mellea.backends.ollama import OllamaModelBackend from typing import Tuple diff --git a/docs/examples/melp/simple_example.py b/docs/examples/melp/simple_example.py index 7ac1059b..7862027e 100644 --- a/docs/examples/melp/simple_example.py +++ b/docs/examples/melp/simple_example.py @@ -1,7 +1,7 @@ import asyncio -from mellea.stdlib.base import Context, CBlock, SimpleContext, ModelOutputThunk -from mellea.backends import Backend +from mellea.core import Context, CBlock, ModelOutputThunk, Backend from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.context import SimpleContext async def main(backend: Backend, ctx: Context): diff --git a/docs/examples/melp/states.py b/docs/examples/melp/states.py index 2383bf4a..d8770c3a 100644 --- a/docs/examples/melp/states.py +++ b/docs/examples/melp/states.py @@ -1,8 +1,10 @@ -from mellea.stdlib.base import SimpleContext, Context, CBlock, SimpleComponent -from mellea.backends import Backend -from mellea.backends.ollama import OllamaModelBackend import asyncio +from mellea.core import Context, CBlock, Backend +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.components import SimpleComponent + async def main(backend: Backend, ctx: Context): a_states = "Alaska,Arizona,Arkansas".split(",") diff --git a/docs/examples/mify/mify.py b/docs/examples/mify/mify.py index 98581a97..b84a29cc 100644 --- a/docs/examples/mify/mify.py +++ b/docs/examples/mify/mify.py @@ -1,5 +1,5 @@ -from 
mellea.stdlib.docs.richdocument import TableQuery -from mellea.stdlib.mify import MifiedProtocol, mify +from mellea.stdlib.components.docs import TableQuery +from mellea.stdlib.components.mify import MifiedProtocol, mify from mellea.stdlib.session import start_session diff --git a/docs/examples/mify/rich_document_advanced.py b/docs/examples/mify/rich_document_advanced.py index 9ee0caed..c3ba2b41 100644 --- a/docs/examples/mify/rich_document_advanced.py +++ b/docs/examples/mify/rich_document_advanced.py @@ -6,7 +6,7 @@ from docling_core.types.doc.document import DoclingDocument import mellea -from mellea.stdlib.base import ModelOutputThunk, TemplateRepresentation +from mellea.core import ModelOutputThunk, TemplateRepresentation # Use a `SimpleContext` so that each LLM call is independent. m = mellea.start_session(backend_name="hf") @@ -35,7 +35,7 @@ # 4. `Mellea` also provides a basic wrapper around this functionality to make # basic processing of documents easier. -from mellea.stdlib.docs.richdocument import RichDocument +from mellea.stdlib.components.docs import RichDocument # This creates a new `Mellea` RichDocument component that encapsulates all # the logic above along with some convenient helpers. diff --git a/docs/examples/mify/rich_table_execute_basic.py b/docs/examples/mify/rich_table_execute_basic.py index 40be7cd4..eabb791e 100644 --- a/docs/examples/mify/rich_table_execute_basic.py +++ b/docs/examples/mify/rich_table_execute_basic.py @@ -3,9 +3,9 @@ from mellea import start_session from mellea.backends import model_ids -from mellea.backends.types import ModelOption -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.docs.richdocument import RichDocument, Table +from mellea.backends import ModelOption +from mellea.core import FancyLogger +from mellea.stdlib.components.docs import RichDocument, Table FancyLogger.get_logger().setLevel("ERROR") diff --git a/docs/examples/mini_researcher/researcher.py b/docs/examples/mini_researcher/researcher.py index 7f332325..3fe3182a 100644 --- a/docs/examples/mini_researcher/researcher.py +++ b/docs/examples/mini_researcher/researcher.py @@ -9,8 +9,9 @@ from mellea import MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.requirement import Requirement, simple_validate -from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult +from mellea.stdlib.requirements import simple_validate +from mellea.stdlib.sampling import RejectionSamplingStrategy +from mellea.core import SamplingResult, Requirement # ############################# # Helper functions diff --git a/docs/examples/mobject/table.py b/docs/examples/mobject/table.py index cc7191c0..03788882 100644 --- a/docs/examples/mobject/table.py +++ b/docs/examples/mobject/table.py @@ -3,8 +3,7 @@ import pandas import mellea -from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.stdlib.mify import mify +from mellea.stdlib.components.mify import mify @mify(fields_include={"table"}, template="{{ table }}") diff --git a/docs/examples/notebooks/context_example.ipynb b/docs/examples/notebooks/context_example.ipynb index 70605429..ec5d03fa 100644 --- a/docs/examples/notebooks/context_example.ipynb +++ b/docs/examples/notebooks/context_example.ipynb @@ -85,7 +85,7 @@ "outputs": [], "source": [ "from mellea import start_session\n", - "from mellea.stdlib.base import ChatContext\n", + "from mellea.stdlib.context import ChatContext\n", "\n", "m = start_session(ctx=ChatContext())\n", 
"m.chat(\"Make up a math problem.\")\n", diff --git a/docs/examples/notebooks/document_mobject.ipynb b/docs/examples/notebooks/document_mobject.ipynb index 37ae1acb..55c7a2b7 100644 --- a/docs/examples/notebooks/document_mobject.ipynb +++ b/docs/examples/notebooks/document_mobject.ipynb @@ -81,7 +81,7 @@ }, "outputs": [], "source": [ - "from mellea.stdlib.docs.richdocument import RichDocument\n", + "from mellea.stdlib.components.docs import RichDocument\n", "\n", "rd = RichDocument.from_document_file(\"https://arxiv.org/pdf/1906.04043\")" ] @@ -102,7 +102,7 @@ }, "outputs": [], "source": [ - "from mellea.stdlib.docs.richdocument import Table\n", + "from mellea.stdlib.components.docs import Table\n", "\n", "table1: Table = rd.get_tables()[0]\n", "print(table1.to_markdown())" @@ -127,7 +127,7 @@ "outputs": [], "source": [ "from mellea import start_session\n", - "from mellea.backends.types import ModelOption\n", + "from mellea.backends import ModelOption\n", "\n", "m = start_session()\n", "for seed in [x * 12 for x in range(5)]:\n", diff --git a/docs/examples/notebooks/georgia_tech.ipynb b/docs/examples/notebooks/georgia_tech.ipynb index 22bc78e9..3b349881 100644 --- a/docs/examples/notebooks/georgia_tech.ipynb +++ b/docs/examples/notebooks/georgia_tech.ipynb @@ -44,7 +44,7 @@ "!uv pip install \"mellea[all]\" -q\n", "\n", "# Run docling once to download model weights.\n", - "from mellea.stdlib.docs.richdocument import RichDocument\n", + "from mellea.stdlib.components.docs.richdocument import RichDocument\n", "\n", "RichDocument.from_document_file(\"https://mellea.ai\")\n", "\n", @@ -114,7 +114,7 @@ "outputs": [], "source": [ "import mellea\n", - "from mellea.stdlib.requirement import check, req, simple_validate\n", + "from mellea.stdlib.requirements import check, req, simple_validate\n", "from mellea.stdlib.sampling import RejectionSamplingStrategy\n", "\n", "requirements = [\n", @@ -141,7 +141,7 @@ " return email_candidate.sample_generations[0].value\n", "\n", "\n", - "m = mellea.start_session()\n", + "m = mellea_org.start_session()\n", "print(\n", " write_email(\n", " m,\n", @@ -387,7 +387,7 @@ }, "outputs": [], "source": [ - "from mellea.stdlib.docs.richdocument import RichDocument\n", + "from mellea.stdlib.components.docs.richdocument import RichDocument\n", "\n", "rd = RichDocument.from_document_file(\"https://arxiv.org/pdf/1906.04043\")" ] @@ -411,7 +411,7 @@ }, "outputs": [], "source": [ - "from mellea.stdlib.docs.richdocument import Table\n", + "from mellea.stdlib.components.docs.richdocument import Table\n", "\n", "table1: Table = rd.get_tables()[0]\n", "print(table1.to_markdown())" @@ -438,7 +438,7 @@ "source": [ "from mellea.backends.model_ids import META_LLAMA_3_2_3B\n", "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.backends.types import ModelOption\n", + "from mellea.backends import ModelOption\n", "\n", "# You can use multiple different models at the same time!\n", "m_llama = mellea.MelleaSession(backend=OllamaModelBackend(model_id=META_LLAMA_3_2_3B))\n", @@ -489,7 +489,7 @@ "\n", "import pandas as pd\n", "\n", - "from mellea.stdlib.mify import mify\n", + "from mellea.stdlib.components.mify import mify\n", "\n", "\n", "@mify(fields_include={\"table\"}, template=\"{{ table }}\")\n", diff --git a/docs/examples/notebooks/instruct_validate_repair.ipynb b/docs/examples/notebooks/instruct_validate_repair.ipynb index 7201e69a..14896c2b 100644 --- a/docs/examples/notebooks/instruct_validate_repair.ipynb +++ 
b/docs/examples/notebooks/instruct_validate_repair.ipynb @@ -85,7 +85,7 @@ "outputs": [], "source": [ "import mellea\n", - "from mellea.stdlib.requirement import check, req, simple_validate\n", + "from mellea.stdlib.requirements import check, req, simple_validate\n", "from mellea.stdlib.sampling import RejectionSamplingStrategy\n", "\n", "requirements = [\n", diff --git a/docs/examples/notebooks/m_serve_example.ipynb b/docs/examples/notebooks/m_serve_example.ipynb index 7fa6e6b0..871349f7 100644 --- a/docs/examples/notebooks/m_serve_example.ipynb +++ b/docs/examples/notebooks/m_serve_example.ipynb @@ -83,9 +83,10 @@ "\n", "import mellea\n", "from cli.serve.models import ChatMessage\n", - "from mellea.stdlib.base import ChatContext, ModelOutputThunk\n", - "from mellea.stdlib.requirement import Requirement, simple_validate\n", - "from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult\n", + "from mellea.stdlib.context import ChatContext\n", + "from mellea.core import ModelOutputThunk, Requirement, SamplingResult\n", + "from mellea.stdlib.requirements import simple_validate\n", + "from mellea.stdlib.sampling import RejectionSamplingStrategy\n", "\n", "session = mellea.start_session(ctx=ChatContext())\n", "\n", diff --git a/docs/examples/notebooks/mcp_example.ipynb b/docs/examples/notebooks/mcp_example.ipynb index 438037dd..50c6233b 100644 --- a/docs/examples/notebooks/mcp_example.ipynb +++ b/docs/examples/notebooks/mcp_example.ipynb @@ -85,8 +85,8 @@ "from mellea import MelleaSession\n", "from mellea.backends import model_ids\n", "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.stdlib.base import ModelOutputThunk\n", - "from mellea.stdlib.requirement import Requirement, simple_validate\n", + "from mellea.core import ModelOutputThunk, Requirement\n", + "from mellea.stdlib.requirements import simple_validate\n", "from mellea.stdlib.sampling import RejectionSamplingStrategy" ] }, diff --git a/docs/examples/notebooks/model_options_example.ipynb b/docs/examples/notebooks/model_options_example.ipynb index a0d6330b..a706c05a 100644 --- a/docs/examples/notebooks/model_options_example.ipynb +++ b/docs/examples/notebooks/model_options_example.ipynb @@ -84,7 +84,7 @@ "import mellea\n", "from mellea.backends import model_ids\n", "from mellea.backends.ollama import OllamaModelBackend\n", - "from mellea.backends.types import ModelOption" + "from mellea.backends import ModelOption" ] }, { diff --git a/docs/examples/notebooks/table_mobject.ipynb b/docs/examples/notebooks/table_mobject.ipynb index aefb0f71..94289994 100644 --- a/docs/examples/notebooks/table_mobject.ipynb +++ b/docs/examples/notebooks/table_mobject.ipynb @@ -86,7 +86,7 @@ "import pandas\n", "\n", "import mellea\n", - "from mellea.stdlib.mify import MifiedProtocol, mify\n", + "from mellea.stdlib.components.mify import MifiedProtocol, mify\n", "\n", "\n", "@mify(fields_include={\"table\"}, template=\"{{ table }}\")\n", diff --git a/docs/examples/rag/mellea_pdf.py b/docs/examples/rag/mellea_pdf.py index dbc0eb0d..c0a64140 100644 --- a/docs/examples/rag/mellea_pdf.py +++ b/docs/examples/rag/mellea_pdf.py @@ -1,5 +1,5 @@ import mellea -from mellea.stdlib.docs.richdocument import RichDocument +from mellea.stdlib.components.docs import RichDocument m = mellea.start_session() diff --git a/docs/examples/safety/guardian.py b/docs/examples/safety/guardian.py index bf31e94e..c5b8b123 100644 --- a/docs/examples/safety/guardian.py +++ b/docs/examples/safety/guardian.py @@ -3,9 +3,10 @@ from mellea import 
MelleaSession from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ContextTurn, ModelOutputThunk, ChatContext -from mellea.stdlib.chat import Message -from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk +from mellea.core import ContextTurn, ModelOutputThunk +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message +from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk # Enhanced GuardianCheck with Granite Guardian 3.3 8B support print("=== Enhanced GuardianCheck Examples ===") @@ -99,7 +100,7 @@ print("\n=== Test 5: Function Call Hallucination Detection ===") # Test function calling hallucination using IBM video example -from mellea.stdlib.base import ModelOutputThunk, ModelToolCall +from mellea.core import ModelOutputThunk, ModelToolCall tools = [ { diff --git a/docs/examples/safety/guardian_huggingface.py b/docs/examples/safety/guardian_huggingface.py index 3cc3d507..bbb84698 100644 --- a/docs/examples/safety/guardian_huggingface.py +++ b/docs/examples/safety/guardian_huggingface.py @@ -8,9 +8,10 @@ from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, ModelOutputThunk, ModelToolCall -from mellea.stdlib.chat import Message -from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk +from mellea.core import ModelOutputThunk, ModelToolCall +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message +from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk print("=== GuardianCheck HuggingFace Backend Example ===") diff --git a/docs/examples/safety/repair_with_guardian.py b/docs/examples/safety/repair_with_guardian.py index c2c1d20a..2355eff5 100644 --- a/docs/examples/safety/repair_with_guardian.py +++ b/docs/examples/safety/repair_with_guardian.py @@ -5,7 +5,7 @@ from mellea import MelleaSession from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk +from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk from mellea.stdlib.sampling import RepairTemplateStrategy diff --git a/docs/examples/sessions/creating_a_new_type_of_session.py b/docs/examples/sessions/creating_a_new_type_of_session.py index a665cf64..59624caf 100644 --- a/docs/examples/sessions/creating_a_new_type_of_session.py +++ b/docs/examples/sessions/creating_a_new_type_of_session.py @@ -2,18 +2,13 @@ from PIL import Image as PILImage from mellea import MelleaSession -from mellea.backends import Backend, BaseModelSubclass +from mellea.core import Backend, BaseModelSubclass from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ( - CBlock, - ChatContext, - Context, - ImageBlock, - ModelOutputThunk, -) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import Requirement, reqify -from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk +from mellea.core import CBlock, Context, ImageBlock, Requirement +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message +from mellea.stdlib.requirements import reqify +from mellea.stdlib.requirements.safety.guardian import GuardianCheck, GuardianRisk # This example shows how you might go about creating a new type of session. 
# Here, we want to filter out potentially harmful chat messages from the user. diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py index ccde0ada..b2a9315b 100644 --- a/docs/examples/tools/interpreter_example.py +++ b/docs/examples/tools/interpreter_example.py @@ -1,8 +1,7 @@ from mellea.stdlib.tools import code_interpreter, local_code_interpreter from mellea import start_session, MelleaSession -from mellea.backends.types import ModelOption -from mellea.backends.model_ids import OPENAI_GPT_OSS_20B -from mellea.stdlib.reqlib.tools import uses_tool, tool_arg_validator +from mellea.backends import ModelOption +from mellea.stdlib.requirements import uses_tool, tool_arg_validator def example_1(m: MelleaSession): diff --git a/docs/examples/tutorial/document_mobject.py b/docs/examples/tutorial/document_mobject.py index 9c6a88a4..19d04bf7 100644 --- a/docs/examples/tutorial/document_mobject.py +++ b/docs/examples/tutorial/document_mobject.py @@ -1,16 +1,16 @@ from mellea.backends import model_ids from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.stdlib.docs.richdocument import RichDocument +from mellea.stdlib.components.docs import RichDocument rd = RichDocument.from_document_file("https://arxiv.org/pdf/1906.04043") -from mellea.stdlib.docs.richdocument import Table # noqa: E402 +from mellea.stdlib.components.docs import Table # noqa: E402 table1: Table = rd.get_tables()[0] print(table1.to_markdown()) from mellea import start_session # noqa: E402 -from mellea.backends.types import ModelOption # noqa: E402 +from mellea.backends import ModelOption # noqa: E402 m = start_session(model_id=model_ids.META_LLAMA_3_2_3B) for seed in [x * 12 for x in range(5)]: diff --git a/docs/examples/tutorial/instruct_validate_repair.py b/docs/examples/tutorial/instruct_validate_repair.py index e3f55f60..76113d0f 100644 --- a/docs/examples/tutorial/instruct_validate_repair.py +++ b/docs/examples/tutorial/instruct_validate_repair.py @@ -1,4 +1,4 @@ -from mellea.stdlib.requirement import check, req, simple_validate +from mellea.stdlib.requirements import check, req, simple_validate requirements = [ req("The email should have a salutation"), # == r1 diff --git a/docs/examples/tutorial/mcp_example.py b/docs/examples/tutorial/mcp_example.py index 48c042b5..25bf5199 100644 --- a/docs/examples/tutorial/mcp_example.py +++ b/docs/examples/tutorial/mcp_example.py @@ -12,8 +12,8 @@ from mellea import MelleaSession from mellea.backends import ModelOption, model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.core import ModelOutputThunk +from mellea.stdlib.requirements import requirement, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy # ################# diff --git a/docs/examples/tutorial/model_options_example.py b/docs/examples/tutorial/model_options_example.py index a583cdc4..7eb88b9a 100644 --- a/docs/examples/tutorial/model_options_example.py +++ b/docs/examples/tutorial/model_options_example.py @@ -1,7 +1,7 @@ import mellea from mellea.backends import model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption m = mellea.MelleaSession( backend=OllamaModelBackend(model_options={ModelOption.SEED: 42}) diff --git a/docs/examples/tutorial/table_mobject.py b/docs/examples/tutorial/table_mobject.py index 
68d33cfd..1fa0eb34 100644 --- a/docs/examples/tutorial/table_mobject.py +++ b/docs/examples/tutorial/table_mobject.py @@ -3,7 +3,7 @@ import pandas import mellea -from mellea.stdlib.mify import MifiedProtocol, mify +from mellea.stdlib.components.mify import MifiedProtocol, mify @mify(fields_include={"table"}, template="{{ table }}") diff --git a/docs/kv_smash/hf_example.py b/docs/kv_smash/hf_example.py index 9a82921b..dc81a64a 100644 --- a/docs/kv_smash/hf_example.py +++ b/docs/kv_smash/hf_example.py @@ -1,8 +1,9 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, ChatContext -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.core import CBlock +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message import asyncio diff --git a/docs/kv_smash/kv_with_chat.py b/docs/kv_smash/kv_with_chat.py index a40afb60..e0f43bc4 100644 --- a/docs/kv_smash/kv_with_chat.py +++ b/docs/kv_smash/kv_with_chat.py @@ -3,8 +3,6 @@ from mellea.backends.huggingface import LocalHFBackend from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches from mellea.backends.model_ids import IBM_GRANITE_3_3_8B -from mellea.stdlib.base import CBlock, ChatContext -from mellea.stdlib.chat import Message backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B) diff --git a/docs/rewrite/session_deepdive/0.py b/docs/rewrite/session_deepdive/0.py index ae95a18a..82d15776 100644 --- a/docs/rewrite/session_deepdive/0.py +++ b/docs/rewrite/session_deepdive/0.py @@ -1,10 +1,8 @@ from mellea import MelleaSession -from mellea.stdlib.base import SimpleContext +from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend -m = MelleaSession( - backend=OllamaModelBackend("granite4:latest"), context=SimpleContext() -) +m = MelleaSession(backend=OllamaModelBackend("granite4:latest"), ctx=SimpleContext()) response = m.chat("What is 1+1?") print(response.content) diff --git a/docs/rewrite/session_deepdive/1.py b/docs/rewrite/session_deepdive/1.py index 1886a0bb..c12e7df6 100644 --- a/docs/rewrite/session_deepdive/1.py +++ b/docs/rewrite/session_deepdive/1.py @@ -1,5 +1,5 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.base import SimpleContext +from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend response, next_context = mfuncs.chat( diff --git a/docs/rewrite/session_deepdive/2.py b/docs/rewrite/session_deepdive/2.py index e20346e2..8002128d 100644 --- a/docs/rewrite/session_deepdive/2.py +++ b/docs/rewrite/session_deepdive/2.py @@ -1,5 +1,6 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.base import SimpleContext, CBlock +from mellea.stdlib.context import SimpleContext +from mellea.core import CBlock from mellea.backends.ollama import OllamaModelBackend response, next_context = mfuncs.act( diff --git a/docs/rewrite/session_deepdive/3.py b/docs/rewrite/session_deepdive/3.py index c43797bd..1d522f77 100644 --- a/docs/rewrite/session_deepdive/3.py +++ b/docs/rewrite/session_deepdive/3.py @@ -1,7 +1,7 @@ import mellea.stdlib.functional as mfuncs -from mellea.stdlib.base import SimpleContext, CBlock, Context +from mellea.core import CBlock, Context, Backend +from mellea.stdlib.context import SimpleContext from mellea.backends.ollama import OllamaModelBackend -from mellea.backends import 
Backend import asyncio diff --git a/docs/rewrite/session_deepdive/4.py b/docs/rewrite/session_deepdive/4.py index 298d4a44..9cd71dd8 100644 --- a/docs/rewrite/session_deepdive/4.py +++ b/docs/rewrite/session_deepdive/4.py @@ -1,7 +1,6 @@ -import mellea.stdlib.functional as mfuncs -from mellea.stdlib.base import SimpleContext, CBlock, Context +from mellea.core import CBlock, Context, Backend from mellea.backends.ollama import OllamaModelBackend -from mellea.backends import Backend +from mellea.stdlib.context import SimpleContext import asyncio diff --git a/docs/rewrite/session_deepdive/5.py b/docs/rewrite/session_deepdive/5.py index e5d024ae..f03369c4 100644 --- a/docs/rewrite/session_deepdive/5.py +++ b/docs/rewrite/session_deepdive/5.py @@ -1,13 +1,8 @@ -import mellea.stdlib.functional as mfuncs -from mellea.stdlib.base import ( - SimpleContext, - CBlock, - Context, - SimpleComponent, - Component, -) +from mellea.core import CBlock, Context, Backend from mellea.backends.ollama import OllamaModelBackend -from mellea.backends import Backend +from mellea.stdlib.components import SimpleComponent +from mellea.stdlib.context import SimpleContext + import asyncio diff --git a/docs/tutorial.md b/docs/tutorial.md index bc8f6628..cdf3ee26 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -191,7 +191,7 @@ Let's look on how we can customize requirement definitions: ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/instruct_validate_repair.py#L1-L10 -from mellea.stdlib.requirement import req, check, simple_validate +from mellea.stdlib.requirements import req, check, simple_validate requirements = [ req("The email should have a salutation"), # == r1 @@ -218,7 +218,7 @@ Now, we bring it all together into a first generative program using the **instru ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/instruct_validate_repair.py#L13-L37 import mellea -from mellea.stdlib.requirement import req, check, simple_validate +from mellea.stdlib.requirements import req, check, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: @@ -262,7 +262,7 @@ You can add any key-value pair supported by the backend to the `model_options` d ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/model_options_example.py#L1-L16 import mellea -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption from mellea.backends.ollama import OllamaModelBackend from mellea.backends import model_ids @@ -570,7 +570,7 @@ Suppose you have a table of sales data and want to let the LLM answer questions ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/table_mobject.py#L1-L31 import mellea -from mellea.stdlib.mify import mify, MifiedProtocol +from mellea.stdlib.components.mify import mify, MifiedProtocol import pandas from io import StringIO @@ -617,7 +617,7 @@ Let's create a RichDocument from an arxiv paper: ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/document_mobject.py#L1-L3 -from mellea.stdlib.docs.richdocument import RichDocument +from mellea.stdlib.components.docs import RichDocument rd = RichDocument.from_document_file("https://arxiv.org/pdf/1906.04043") ``` this loads the PDF file and parses it using the Docling parser into an @@ -627,7 +627,7 @@ From the rich document we 
can extract some document content, e.g. the first table: ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/document_mobject.py#L5-L8 -from mellea.stdlib.docs.richdocument import Table +from mellea.stdlib.components.docs import Table table1: Table = rd.get_tables()[0] print(table1.to_markdown()) ``` @@ -646,7 +646,7 @@ The `Table` object is Mellea-ready and can be used immediately with LLMs. Let's just get it to work: ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/document_mobject.py#L10-L24 -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption from mellea import start_session m = start_session() @@ -693,7 +693,7 @@ The model has done a great job at fulfilling the task and coming back with a par When an object is `mified` all methods with a docstring get registered as tools for the LLM call. You can control if you only want a subset of these functions to be exposed by two parameters (`funcs_include` and `funcs_exclude`): ```python -from mellea.stdlib.mify import mify +from mellea.stdlib.components.mify import mify @mify(funcs_include={"from_markdown"}) class MyDocumentLoader: @@ -1319,7 +1319,7 @@ For examples on adding tools to the template representation of a component, see Here's an example of adding a tool through model options. This can be useful when you want to add a tool like web search that should almost always be available: ```python -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption def web_search(query: str) -> str: ... diff --git a/mellea/__init__.py b/mellea/__init__.py index a8fc24fa..3fd46b10 100644 --- a/mellea/__init__.py +++ b/mellea/__init__.py @@ -1,7 +1,7 @@ -"""Mellea is a library for building robust LLM applications.""" +"""Mellea.""" -import mellea.backends.model_ids as model_ids -from mellea.stdlib.genslot import generative -from mellea.stdlib.session import MelleaSession, start_session +from .backends import model_ids +from .stdlib.components.genslot import generative +from .stdlib.session import MelleaSession, start_session __all__ = ["MelleaSession", "generative", "model_ids", "start_session"] diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index 2ebb0059..a8e7ded6 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -1,154 +1,8 @@ -"""Backends (e.g., ollama, huggingface, openai-compatible) communicate with LLMs.""" - -from __future__ import annotations - -import abc -import asyncio -import itertools -from collections.abc import Sequence -from typing import Any, overload - -import pydantic -import typing_extensions - -from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.types import ModelOption -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import C, CBlock, Component, Context, ModelOutputThunk - -# Necessary to define a type that supports `None` so that the BaseModelSubclass -# can have a default value. Otherwise, Python complains about typed-components -# since types with default values must come after those without default values in -# function signatures (which is incompatible with our function parameter formatting). 
-pydantic_model_or_none = pydantic.BaseModel | None -BaseModelSubclass = typing_extensions.TypeVar( - "BaseModelSubclass", bound=pydantic_model_or_none, default=None -) # must be a subclass of BaseModel - - -class Backend(abc.ABC): - """An abstract `Backend`.""" - - def __init__( - self, model_id: str | ModelIdentifier, *, model_options: dict | None = None - ): - """All backends need to be instantiated with a `model_id`. - - A backend can support multiple models, but each instance of a backend corresponds to exactly one model. - - Args: - model_id (str | ModelIdentifier): The model_id for this model. - model_options (Optional[dict]): If set, these model options will be used. Otherwise an empty model options dictionary will be used. - """ - self.model_id = model_id - self.model_options = model_options if model_options is not None else {} - - @abc.abstractmethod - async def generate_from_context( - self, - action: Component[C] | CBlock, - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> tuple[ModelOutputThunk[C], Context]: - """Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request. - - Args: - action: The last item of the context should be passed in as an `action` instead of as part of the `ctx`. See `docs/dev/generate_signature_decisions.md`. - ctx: The rest of the context. - format: A response format to used for structured outputs / constrained decoding. - model_options: Any model options to upsert into the defaults for this call. - tool_calls: If `True`, then tool calls are extracts from the `action` `Component`. Assumption: if tool_calls is enabled, then the action `Component` has a TemplateRepresentation - - Returns: - a tuple of (ModelOutputThunk, Context) where the Context is the new context after the generation has been completed. - """ - ... - - @overload - async def generate_from_raw( - self, - actions: list[Component[C]], - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> list[ModelOutputThunk[C]]: ... - - @overload - async def generate_from_raw( - self, - actions: list[Component[C] | CBlock], - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> list[ModelOutputThunk[C | str]]: ... - - @abc.abstractmethod - async def generate_from_raw( - self, - actions: Sequence[Component[C] | CBlock], - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> list[ModelOutputThunk]: - """Generates a model output from the provided input. Does not use context or templates. - - Args: - actions: list of actions to generate responses for. Each action is separate. - ctx: context passed to generation. Currently not used in generate_from_raw - format: A response format to used for structured outputs / constrained decoding. Note: some backends do not support this parameter. They will log warnings and continue to generate. - model_options: Any model options to upsert into the defaults for this call. - tool_calls: Always set to false unless supported by backend. 
- """ - - async def do_generate_walk( - self, action: CBlock | Component | ModelOutputThunk - ) -> None: - """Does the generation walk.""" - _to_compute = list(generate_walk(action)) - coroutines = [x.avalue() for x in _to_compute] - # The following log message might get noisy. Feel free to remove if so. - if len(_to_compute) > 0: - FancyLogger.get_logger().info( - f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." - ) - await asyncio.gather(*coroutines) - - async def do_generate_walks( - self, actions: list[CBlock | Component | ModelOutputThunk] - ) -> None: - """Does the generation walk.""" - _to_compute = [] - for action in actions: - _to_compute.extend(list(generate_walk(action))) - coroutines = [x.avalue() for x in _to_compute] - # The following log message might get noisy. Feel free to remove if so. - if len(_to_compute) > 0: - FancyLogger.get_logger().info( - f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." - ) - await asyncio.gather(*coroutines) - - -def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]: - """Returns the generation walk ordering for a Span.""" - match c: - case ModelOutputThunk() if not c.is_computed(): - return [c] - case CBlock(): - return [] - case Component(): - parts_walk = [generate_walk(p) for p in c.parts()] - return list(itertools.chain.from_iterable(parts_walk)) # aka flatten - case _: - raise ValueError( - f"parts should only contain CBlocks, Components, or ModelOutputThunks; found `{c!s:.10}{'...' if len(str(c)) > 10 else ''}` (type: {type(c)})" - ) +"""Backend implementations.""" + +# Import from core for ergonomics. +from ..core import Backend, BaseModelSubclass +from .backend import FormatterBackend +from .cache import SimpleLRUCache +from .model_ids import ModelIdentifier +from .model_options import ModelOption diff --git a/mellea/backends/adapters/__init__.py b/mellea/backends/adapters/__init__.py new file mode 100644 index 00000000..367503be --- /dev/null +++ b/mellea/backends/adapters/__init__.py @@ -0,0 +1,11 @@ +"""Classes and Functions for Backend Adapters.""" + +from .adapter import ( + AdapterMixin, + AdapterType, + GraniteCommonAdapter, + LocalHFAdapter, + OpenAIAdapter, + fetch_intrinsic_metadata, + get_adapter_for_intrinsic, +) diff --git a/mellea/backends/adapters/adapter.py b/mellea/backends/adapters/adapter.py index 21993e85..46973f1a 100644 --- a/mellea/backends/adapters/adapter.py +++ b/mellea/backends/adapters/adapter.py @@ -7,9 +7,9 @@ import granite_common.intrinsics import yaml -from mellea.backends import Backend -from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata -from mellea.backends.types import _ServerType +from ...core import Backend +from ...helpers import _ServerType +from .catalog import AdapterType, fetch_intrinsic_metadata class Adapter(abc.ABC): diff --git a/mellea/backends/backend.py b/mellea/backends/backend.py new file mode 100644 index 00000000..210a70aa --- /dev/null +++ b/mellea/backends/backend.py @@ -0,0 +1,41 @@ +"""FormatterBackend.""" + +import abc +from enum import Enum +from urllib.parse import urlparse + +from ..core import Backend +from ..formatters import ChatFormatter +from .model_ids import ModelIdentifier + + +class FormatterBackend(Backend, abc.ABC): + """`FormatterBackend`s support legacy model types. 
+ + The `mellea` library was designed to support generative computing with [spanned attention](https://generative.computing/what-are-spans.html) over [generative programming primitives](https://generative.computing/what-are-generative-programs.html). + In the ideal world, context management is handled via span scope-relations and all generative programming primitives are baked into the model via fine-tuning. + I.e., the model's instruction tuning is done in terms of generative programming primitives, and the model is then prompted with the same set of templates that were used for that tuning. + + Today, most models do not yet support spans and even those that do are not properly tuned to leverage generative programming primitives. + The `mellea` library supports these legacy models primarily through prompt engineering surfaced via `FormatterBackends`. + A `FormatterBackend` is a backend that uses hand-engineered prompts for rendering generative programming primitives to a model and parsing responses from the model back into `mellea`. + By default, a `FormatterBackend` uses jinja2 templates for pretty-printing, and relies on the user's ad-hoc logic for parsing. + """ + + def __init__( + self, + model_id: str | ModelIdentifier, + formatter: ChatFormatter, + *, + model_options: dict | None = None, + ): + """Initializes a formatter-based backend for `model_id`. + + Args: + model_id (str): The model_id to use. + formatter (Formatter): The formatter to use for converting components into (fragments of) prompts. + model_options (Optional[dict]): The model options to use; if None, sensible defaults will be provided. + """ + self.model_id = model_id + self.model_options = model_options if model_options is not None else {} + self.formatter: ChatFormatter = formatter diff --git a/mellea/backends/dummy.py b/mellea/backends/dummy.py index 3b45999d..71812d9a 100644 --- a/mellea/backends/dummy.py +++ b/mellea/backends/dummy.py @@ -1,7 +1,14 @@ """This module holds shim backends used for smoke tests.""" -from mellea.backends import Backend, BaseModelSubclass -from mellea.stdlib.base import C, CBlock, Component, Context, ModelOutputThunk +from ..core import ( + Backend, + BaseModelSubclass, + C, + CBlock, + Component, + Context, + ModelOutputThunk, +) class DummyBackend(Backend): diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 7dca2cac..48e74543 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -5,68 +5,62 @@ from __future__ import annotations -import abc import asyncio import dataclasses import datetime import functools -import inspect import json import threading from collections.abc import Callable, Coroutine, Sequence -from copy import deepcopy -from typing import TYPE_CHECKING, Any, TypeVar, cast, overload +from typing import Any, overload import granite_common import outlines import outlines_core import peft import torch -from transformers import ( - AsyncTextIteratorStreamer, - AutoModelForCausalLM, - AutoTokenizer, - DynamicCache, - PreTrainedModel, - PreTrainedTokenizer, - set_seed, -) +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.cache_utils import DynamicCache +from transformers.generation.streamers import AsyncTextIteratorStreamer from transformers.generation.utils import GenerateDecoderOnlyOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.trainer_utils import set_seed -from mellea.backends import 
BaseModelSubclass, kv_block_helpers -from mellea.backends._utils import to_chat, to_tool_calls -from mellea.backends.adapters.adapter import ( +from ..backends import kv_block_helpers +from ..core import ( + BaseModelSubclass, + C, + CBlock, + Component, + Context, + FancyLogger, + GenerateLog, + GenerateType, + ModelOutputThunk, + Requirement, +) +from ..formatters import ChatFormatter, TemplateFormatter +from ..helpers import message_to_openai_message, messages_to_docs, send_to_queue +from ..stdlib.components import Intrinsic, Message +from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement +from .adapters import ( AdapterMixin, AdapterType, GraniteCommonAdapter, LocalHFAdapter, get_adapter_for_intrinsic, ) -from mellea.backends.cache import Cache, SimpleLRUCache -from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.openai import OpenAIBackend -from mellea.backends.process_reward_models import PRM -from mellea.backends.tools import ( +from .backend import FormatterBackend +from .cache import Cache, SimpleLRUCache +from .model_ids import ModelIdentifier +from .model_options import ModelOption +from .tools import ( add_tools_from_context_actions, add_tools_from_model_options, convert_tools_to_json, ) -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import send_to_queue -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( - C, - CBlock, - Component, - Context, - GenerateLog, - GenerateType, - ModelOutputThunk, -) -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics.intrinsic import Intrinsic -from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement +from .utils import to_chat, to_tool_calls assert outlines, "outlines needs to be present to make outlines_core work" @@ -102,7 +96,7 @@ class LocalHFBackend(FormatterBackend, AdapterMixin): def __init__( self, model_id: str | ModelIdentifier, - formatter: Formatter | None = None, + formatter: ChatFormatter | None = None, *, use_caches: bool = True, cache: Cache | None = None, @@ -313,11 +307,9 @@ async def _generate_from_intrinsic( if system_prompt != "": conversation.append({"role": "system", "content": system_prompt}) - conversation.extend( - [OpenAIBackend.message_to_openai_message(m) for m in ctx_as_message_list] - ) + conversation.extend([message_to_openai_message(m) for m in ctx_as_message_list]) - docs = OpenAIBackend.messages_to_docs(ctx_as_message_list) + docs = messages_to_docs(ctx_as_message_list) seed = model_options.get(ModelOption.SEED, None) if seed is not None: @@ -373,7 +365,7 @@ async def _generate_from_intrinsic( # us having specific caching for each Component/Message. generate_input, other_input = ( - granite_common.util.chat_completion_request_to_transformers_inputs( + granite_common.util.chat_completion_request_to_transformers_inputs( # type: ignore rewritten, self._tokenizer, self._model ) ) @@ -381,7 +373,7 @@ async def _generate_from_intrinsic( chat_response = asyncio.to_thread( self._generate_with_adapter_lock, adapter.qualified_name, - granite_common.util.generate_with_transformers, + granite_common.util.generate_with_transformers, # type: ignore # Passed as args/kwargs to generate. 
self._tokenizer, self._model, @@ -605,7 +597,7 @@ async def _generate_from_context_with_kv_cache( from outlines.models.transformers import TransformerTokenizer from outlines.processors.structured import RegexLogitsProcessor - from transformers import LogitsProcessorList + from transformers import LogitsProcessorList # type: ignore format_kwargs["logits_processor"] = LogitsProcessorList( [ @@ -772,7 +764,7 @@ async def _generate_from_context_standard( from outlines.models.transformers import TransformerTokenizer from outlines.processors.structured import RegexLogitsProcessor - from transformers import LogitsProcessorList + from transformers import LogitsProcessorList # type: ignore format_kwargs["logits_processor"] = LogitsProcessorList( [ @@ -1019,7 +1011,7 @@ async def generate_from_raw( from outlines.models.transformers import TransformerTokenizer from outlines.processors.structured import RegexLogitsProcessor - from transformers import LogitsProcessorList + from transformers import LogitsProcessorList # type: ignore format_kwargs["logits_processor"] = LogitsProcessorList( [ @@ -1281,56 +1273,3 @@ def _assert_correct_adapters(expected_state: str, model: PreTrainedModel): ) else: raise e - - -class HFProcessRewardModel(PRM, abc.ABC): - """A Process Reward Model that works with a huggingface backend.""" - - def __init__( - self, model_name_or_path: str, score_token: str, device: str | None = None - ): - """Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models. - - Args: - model_name_or_path (str): A local path to PRM or a huggingface PRM - score_token (str): token who's logits correspond to the PRM score. Can be a step demarker (for non-generative PRMs) or a correctness indicator (for generative PRMs) - device (str): device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. - """ - super().__init__(model_name_or_path) - - # auto-device if not more specific - self._device = device - if device is None: - device_name: str = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - assert device_name is not None - self._device = torch.device(device_name) # type: ignore - - self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - self.model_name_or_path, torch_dtype=torch.bfloat16 - ) - self.model.eval() - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) - - self._score_token = score_token - self._score_token_id = self.tokenizer.encode( - self._score_token, add_special_tokens=False - )[0] - - def stepify(self, content: str, step_separator: str) -> list[str]: - """Splits the assistant response into steps to score. 
- - Args: - content: assistant response to score - step_separator: string on which to separate the content into steps - """ - # convert assistant message into a list of steps - list_of_steps = [ - step.strip() for step in content.split(step_separator) if step.strip != "" - ] - return list_of_steps diff --git a/mellea/backends/kv_block_helpers.py b/mellea/backends/kv_block_helpers.py index ef54b80b..90d6222b 100644 --- a/mellea/backends/kv_block_helpers.py +++ b/mellea/backends/kv_block_helpers.py @@ -5,7 +5,8 @@ from typing import Any import torch -from transformers import BatchEncoding, DynamicCache +from transformers.cache_utils import DynamicCache +from transformers.tokenization_utils_base import BatchEncoding TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache] LegacyCache = Any diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index a1de8951..4dc82da8 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -4,42 +4,43 @@ import datetime import functools import json -import os from collections.abc import Callable, Coroutine, Sequence from typing import Any, overload -import litellm # type: ignore -import litellm.litellm_core_utils # type: ignore -import litellm.litellm_core_utils.get_supported_openai_params # type: ignore +import litellm +import litellm.litellm_core_utils +import litellm.litellm_core_utils.get_supported_openai_params -import mellea.backends.model_ids as model_ids -from mellea.backends import BaseModelSubclass -from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter -from mellea.backends.openai import OpenAIBackend -from mellea.backends.tools import ( - add_tools_from_context_actions, - add_tools_from_model_options, - convert_tools_to_json, -) -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue -from mellea.helpers.fancy_logger import FancyLogger -from mellea.helpers.openai_compatible_helpers import ( - chat_completion_delta_merge, - extract_model_tool_requests, -) -from mellea.stdlib.base import ( +from ..backends import model_ids +from ..core import ( + BaseModelSubclass, C, CBlock, Component, Context, + FancyLogger, GenerateLog, GenerateType, ModelOutputThunk, ModelToolCall, ) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import ALoraRequirement +from ..formatters import ChatFormatter, TemplateFormatter +from ..helpers import ( + chat_completion_delta_merge, + extract_model_tool_requests, + get_current_event_loop, + message_to_openai_message, + send_to_queue, +) +from ..stdlib.components import Message +from ..stdlib.requirements import ALoraRequirement +from .backend import FormatterBackend +from .model_options import ModelOption +from .tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, + convert_tools_to_json, +) format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors @@ -51,7 +52,7 @@ def __init__( self, model_id: str = "ollama_chat/" + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name), - formatter: Formatter | None = None, + formatter: ChatFormatter | None = None, base_url: str | None = "http://localhost:11434", model_options: dict | None = None, ): @@ -269,9 +270,7 @@ async def _generate_from_chat_context_standard( system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "") if system_prompt != "": conversation.append({"role": "system", "content": system_prompt}) - conversation.extend( 
- [OpenAIBackend.message_to_openai_message(m) for m in messages] - ) + conversation.extend([message_to_openai_message(m) for m in messages]) extra_params: dict[str, Any] = {} if _format is not None: diff --git a/mellea/backends/model_ids.py b/mellea/backends/model_ids.py index 3ffdb4b8..90329caa 100644 --- a/mellea/backends/model_ids.py +++ b/mellea/backends/model_ids.py @@ -1,7 +1,6 @@ -"""Utilities for model identifiers.""" +"""Dataclasses for ModelIdentifiers.""" import dataclasses -from typing import Optional @dataclasses.dataclass(frozen=True) diff --git a/mellea/backends/types.py b/mellea/backends/model_options.py similarity index 82% rename from mellea/backends/types.py rename to mellea/backends/model_options.py index 3ee0571d..e464855b 100644 --- a/mellea/backends/types.py +++ b/mellea/backends/model_options.py @@ -1,10 +1,8 @@ -"""Useful type definitions for models, formatters, and backends.""" +"""Common ModelOptions for Backend Generation.""" -from enum import Enum from typing import Any -from urllib.parse import urlparse -from mellea.helpers.fancy_logger import FancyLogger +from ..core import FancyLogger class ModelOption: @@ -15,7 +13,7 @@ class ModelOption: Create a dictionary containing model options like this: ```python - from mellea.backends.types import ModelOption + from mellea.backends import ModelOption model_options = { ModelOption.TEMPERATURE : 0.0, ModelOption.SYSTEM_PROMPT : "You are a helpful assistant" @@ -116,27 +114,3 @@ def merge_model_options( for k, v in overwrite_opts.items(): new_options[k] = v return new_options - - -class _ServerType(Enum): - """Different types of servers that might be relevant for a backend.""" - - UNKNOWN = 0 - LOCALHOST = 1 - OPENAI = 2 - REMOTE_VLLM = 3 - """Must be set manually for now.""" - - -def _server_type(url: str) -> _ServerType: - """Find a server type based on the url.""" - try: - parsed = urlparse(url) - hostname = parsed.hostname - if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): - return _ServerType.LOCALHOST - elif hostname == "api.openai.com": - return _ServerType.OPENAI - except Exception as e: - print(f"Error parsing URL: {e}") - return _ServerType.UNKNOWN diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 94bfb2cc..1c0c9a20 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -9,34 +9,26 @@ import ollama from tqdm import tqdm -import mellea.backends.model_ids as model_ids -from mellea.backends import BaseModelSubclass, generate_walk -from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import ( - add_tools_from_context_actions, - add_tools_from_model_options, -) -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import ( - ClientCache, - get_current_event_loop, - send_to_queue, -) -from mellea.helpers.event_loop_helper import _run_async_in_thread -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( +from ..backends import ModelIdentifier, model_ids +from ..core import ( + BaseModelSubclass, C, CBlock, Component, Context, + FancyLogger, GenerateLog, GenerateType, ModelOutputThunk, ModelToolCall, ) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import ALoraRequirement +from ..formatters import ChatFormatter, TemplateFormatter +from ..helpers import ClientCache, get_current_event_loop, send_to_queue +from ..stdlib.components import Message +from 
..stdlib.requirements import ALoraRequirement +from .backend import FormatterBackend +from .model_options import ModelOption +from .tools import add_tools_from_context_actions, add_tools_from_model_options format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors @@ -47,7 +39,7 @@ class OllamaModelBackend(FormatterBackend): def __init__( self, model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B, - formatter: Formatter | None = None, + formatter: ChatFormatter | None = None, base_url: str | None = None, model_options: dict | None = None, ): @@ -434,12 +426,7 @@ async def generate_from_raw( model_opts = self._simplify_and_merge(model_options) - _to_compute = [] - for act in actions: - _to_compute.extend(generate_walk(act)) - parts_coroutines = [x.avalue() for x in _to_compute] - await asyncio.gather(*parts_coroutines) - + await self.do_generate_walks(list(actions)) prompts = [self.formatter.print(action) for action in actions] # Ollama doesn't support "batching". There's some ability for concurrency. Use that here. diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index f6c3c21b..d7eb284d 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -1,17 +1,12 @@ """A generic OpenAI compatible backend that wraps around the openai python sdk.""" -import abc import asyncio import datetime import functools import inspect -import json import os from collections.abc import Callable, Coroutine, Sequence -from copy import deepcopy -from enum import Enum from typing import TYPE_CHECKING, Any, overload -from urllib.parse import urlparse import granite_common import openai @@ -20,46 +15,47 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.completion import Completion -import mellea.backends.model_ids as model_ids -from mellea.backends import BaseModelSubclass -from mellea.backends.adapters.adapter import ( +from ..backends import ModelIdentifier, model_ids +from ..core import ( + BaseModelSubclass, + C, + CBlock, + Component, + Context, + FancyLogger, + GenerateLog, + GenerateType, + ModelOutputThunk, + Requirement, +) +from ..formatters import ChatFormatter, TemplateFormatter +from ..helpers import ( + ClientCache, + _server_type, + _ServerType, + chat_completion_delta_merge, + extract_model_tool_requests, + get_current_event_loop, + message_to_openai_message, + messages_to_docs, + send_to_queue, +) +from ..stdlib.components import Intrinsic, Message +from ..stdlib.requirements import ALoraRequirement, LLMaJRequirement +from .adapters import ( AdapterMixin, AdapterType, GraniteCommonAdapter, OpenAIAdapter, get_adapter_for_intrinsic, ) -from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import ( +from .backend import FormatterBackend +from .model_options import ModelOption +from .tools import ( add_tools_from_context_actions, add_tools_from_model_options, convert_tools_to_json, ) -from mellea.backends.types import ModelOption, _server_type, _ServerType -from mellea.helpers.async_helpers import ( - ClientCache, - get_current_event_loop, - send_to_queue, -) -from mellea.helpers.fancy_logger import FancyLogger -from mellea.helpers.openai_compatible_helpers import ( - chat_completion_delta_merge, - extract_model_tool_requests, -) -from mellea.stdlib.base import ( - C, - CBlock, - Component, - Context, - Document, - GenerateLog, - 
GenerateType, - ModelOutputThunk, -) -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics.intrinsic import Intrinsic -from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement if TYPE_CHECKING: from transformers.tokenization_utils import PreTrainedTokenizer @@ -75,7 +71,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin): def __init__( self, model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1, - formatter: Formatter | None = None, + formatter: ChatFormatter | None = None, base_url: str | None = None, model_options: dict | None = None, *, @@ -413,8 +409,8 @@ async def _generate_from_intrinsic( system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "") if system_prompt != "": conversation.append({"role": "system", "content": system_prompt}) - conversation.extend([self.message_to_openai_message(m) for m in messages]) - docs = self.messages_to_docs(messages) + conversation.extend([message_to_openai_message(m) for m in messages]) + docs = messages_to_docs(messages) if model_opts.get(ModelOption.STREAM, None) is not None: # Intrinsics don't support streaming because of their post-processing step. @@ -525,59 +521,6 @@ async def granite_common_processing( return output - @staticmethod - def message_to_openai_message(msg: Message): - """Serializes a mellea Message object to the message format required by OpenAI compatible api providers.""" - if msg.images is not None: - img_list = [ - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{img}"}, - } - for img in msg.images - ] - - return { - "role": msg.role, - "content": [{"type": "text", "text": msg.content}, *img_list], - } - else: - return {"role": msg.role, "content": msg.content} - # Target format: - # { - # "role": "user", - # "content": [ - # { - # "type": "text", - # "text": "What's in this picture?" - # }, - # { - # "type": "image_url", - # "image_url": { - # "url": "data:image/jpeg;base64," - # } - # } - # ] - # } - - @staticmethod - def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]: - """Extracts the docs from a list of messages.""" - docs: list[Document] = [] - for message in msgs: - if message._docs is not None: - docs.extend(message._docs) - - json_docs: list[dict[str, str]] = [] - for doc in docs: - json_doc = {"text": doc.text} - if doc.title is not None: - json_doc["title"] = doc.title - if doc.doc_id is not None: - json_doc["doc_id"] = doc.doc_id - json_docs.append(json_doc) - return json_docs - async def _generate_from_chat_context_standard( self, action: Component | CBlock, @@ -610,7 +553,7 @@ async def _generate_from_chat_context_standard( system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "") if system_prompt != "": conversation.append({"role": "system", "content": system_prompt}) - conversation.extend([self.message_to_openai_message(m) for m in messages]) + conversation.extend([message_to_openai_message(m) for m in messages]) extra_params: dict[str, Any] = {} if _format is not None: diff --git a/mellea/backends/process_reward_models/__init__.py b/mellea/backends/process_reward_models/__init__.py deleted file mode 100644 index 5097d5f3..00000000 --- a/mellea/backends/process_reward_models/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers).""" - -import abc - - -class PRM(abc.ABC): - """Mixin for Process Reward Model Backends.""" - - def __init__(self, model_name_or_path): - """Sets the self.model_name_or_path. 
Inheriting classes should implement the remaining logic.""" - # Leave implementation of model to inheriting class - self.model_name_or_path = model_name_or_path - - @abc.abstractmethod - def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: - """Returns a final score and per-step score to the input of the model.""" - ... - - @abc.abstractmethod - def stepify(self, response: str, step_separator: str) -> list[str]: - """Splits the assistant response into steps to score. - - Args: - response: assistant response to score - step_separator: string on which to separate the response into steps - """ - ... diff --git a/mellea/backends/process_reward_models/huggingface/__init__.py b/mellea/backends/process_reward_models/huggingface/__init__.py deleted file mode 100644 index 35adf9ac..00000000 --- a/mellea/backends/process_reward_models/huggingface/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Process Reward Model Implementations with Huggingface backends.""" diff --git a/mellea/backends/process_reward_models/huggingface/prms.py b/mellea/backends/process_reward_models/huggingface/prms.py deleted file mode 100644 index 2525b8e6..00000000 --- a/mellea/backends/process_reward_models/huggingface/prms.py +++ /dev/null @@ -1,258 +0,0 @@ -"""PRM Implementations for Local HuggingFace Backends.""" - -import torch -from transformers.tokenization_utils_base import BatchEncoding - -from mellea.backends.huggingface import HFProcessRewardModel - - -class HFGenerativePRM(HFProcessRewardModel): - """A Generative PRM that works with a huggingface backend.""" - - def __init__( - self, - model_name_or_path: str = "ibm-granite/granite-3.3-8b-lora-math-prm", - score_token: str = "Y", - device: str | None = None, - generation_prompt: str = "Is this response correct so far (Y/N)?", - step_separator: str = "\n\n", - ): - """Initialize a Generative PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models. - - Args: - model_name_or_path (str): A local path to PRM or a huggingface PRM - score_token (str): token who's logits correspond to the PRM score. Usually is a correctness indicator (for generative PRMs) - device (str): pointer to device - generation_prompt (str): Optional prompt to be added before generation - step_separator (str): string on which to separate the content into steps - """ - super().__init__(model_name_or_path, score_token, device) - self.generation_prompt = ( - generation_prompt if generation_prompt is not None else "" - ) - self.step_separator = step_separator - self.softmax = torch.nn.Softmax(dim=-1) - - def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: - """Returns a final and per-step score for a given input query and response. 
- - Args: - query (str): User query - response (str): Assistant Response to score - """ - list_of_steps = self.stepify(response, self.step_separator) - # get tokenized batch - batches = self.prepare_inputs(query, list_of_steps) - all_rewards = [] - all_rewards_per_step = [] - - # find the chat turn where assistant message starts to find the correct placement of the score token - # add empty system prompt to prevent model from adding its own system prompt - chat_template_to_turn = self.tokenizer.apply_chat_template( - [ - {"role": "system", "content": ""}, - {"role": "assistant", "content": self._score_token}, - ], - tokenize=False, - add_generation_prompt=False, - ) - # removing the system prompt by finding the assistant turn, which usually starts like <|..|>assistant<|..> - asst_text = chat_template_to_turn[chat_template_to_turn.find(">assistant<") :][ - 1: - ] - asst_toks = self.tokenizer( - asst_text, add_special_tokens=False, return_tensors="pt" - )["input_ids"][0] - asst_toks_before_correct_token = asst_toks[ - : torch.where(asst_toks == self._score_token_id)[ - 0 - ].item() # type: ignore - ].tolist() # type : ignore - - # move each item of the batch to the device - for i in batches: - batches[i] = batches[i].to(self.model.device) - - with torch.no_grad(): - model_outputs = self.model(**batches) - logits = model_outputs.logits # (bsz, seq_len, vocab_size) - - for batch_idx in range(logits.shape[0]): - per_input_rewards = [] - # for each element in the batch (i.e., each input) - # we need to get logits for all tokens where the token is self._score_token (in assistant turn) - # find batch index for **assistant** turn is self._score_token, not just the self._score_token_id - correct_token_indices = torch.where( - batches["input_ids"][batch_idx] == self._score_token_id - )[0].tolist() - prm_indices = [] - for t_idx in correct_token_indices: - if ( - batches["input_ids"][batch_idx][ - t_idx - len(asst_toks_before_correct_token) : t_idx - ].tolist() - == asst_toks_before_correct_token - ): - prm_indices.append( - t_idx - 1 - ) # the logits for token i predict the token i+1: so, we need to look at the **previous** token logits - - assert len(prm_indices) > 0 - # convert logits to probabilities and get the probability of the correct token id as reward - for prm_idx in prm_indices: - per_input_rewards.append( - self.softmax(logits[batch_idx, prm_idx, :])[ - self._score_token_id - ].item() - ) - - # aggregate. return final rewards - all_rewards_per_step.append(per_input_rewards) - sum = 0 - for reward in per_input_rewards: - sum += reward - per_input_reward = sum / len(per_input_rewards) - all_rewards.append(per_input_reward) - - return all_rewards, all_rewards_per_step - - def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: - """Prepare the inputs for inference with the model. 
- - Args: - user_content (str): the user query - steps (List(str)): assistant response, broken down into steps - """ - msgs = [] - for s_idx, step in enumerate(steps): - # apply chat template as expected by the reward model - # rewards are calculated from the logit of self._score_token as produced by the assistant - if s_idx == 0: - msgs.append( - { - "role": "user", - "content": user_content - + " " - + step - + " " - + self.generation_prompt, - } - ) - else: - # first add last assistant turn - msgs.append({"role": "assistant", "content": self._score_token}) - msgs.append( - {"role": "user", "content": step + " " + self.generation_prompt} - ) - - # append last assistant turn - msgs.append({"role": "assistant", "content": self._score_token}) - input_message = self.tokenizer.apply_chat_template( - msgs, add_generation_prompt=False, tokenize=False - ) - return self.tokenizer( - [input_message], return_tensors="pt", padding=True, truncation=True - ) - - -class HFRegressionPRM(HFProcessRewardModel): - """A Regression PRM that works with a huggingface backend.""" - - def __init__( - self, - model_name_or_path: str, - score_token: str = "", - device: str | None = None, - step_separator: str = "\n\n", - ): - """Initialize a Regression PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models. - - Args: - model_name_or_path (str): A local path to PRM or a huggingface PRM - score_token (str): token who's logits correspond to the PRM score. Usually is a step demarker (for non-generative PRMs) - device (str): pointer to the device on which to run the model - step_separator (str): string on which to separate the input content into steps - """ - super().__init__(model_name_or_path, score_token, device) - - # initialize PRM head - self.prm_head = torch.nn.Linear( - self.model.config.hidden_size, 2, bias=False, dtype=self.model.dtype - ).to(self.model.device) - - state = torch.load(model_name_or_path + "/added_params.bin") - # need to do this-- we save model dict as `prm_head.weight` during training - new_state_dict = {} - for k, v in state.items(): - new_k = k.replace("prm_head.", "") - new_state_dict[new_k] = v - - self.prm_head.load_state_dict(new_state_dict) - self.prm_head.eval() - - self.step_separator = step_separator - self.softmax = torch.nn.Softmax(dim=-1) - - def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: - """Returns a final and per-step score for a given input query and response. - - Args: - query (str): User query - response (str): Assistant Response to score - """ - list_of_steps = self.stepify(response, self.step_separator) - # tokenizes the batch and concatenates the list of steps into a single step-separated response - batch = self.prepare_inputs(query, list_of_steps) - # move each item of the batch to the device - for i in batch: - batch[i] = batch[i].to(self.model.device) - - with torch.no_grad(): - model_outputs = self.model(**batch, output_hidden_states=True) - # all logits - all_prm_logits = self.prm_head(model_outputs["hidden_states"][-1]).squeeze( - -1 - ) - - # get logits for each end of step i.e. 
logits for step_eos positions in the input - prm_probs = [] - rewards = [] - for idx in range(all_prm_logits.shape[0]): - prm_indices = torch.where(batch["input_ids"][idx] == self._score_token_id)[ - 0 - ] - assert prm_indices.shape[0] > 0 - # head produces two logits, the second one is the logit for the correct answer - # convert logits to probabilities using softmax - # return list of floats instead of list of tensors - prm_probs_per_sample = [ - t.item() for t in self.softmax(all_prm_logits[idx][prm_indices])[:, 1] - ] - prm_probs.append(prm_probs_per_sample) - - reward = sum(prm_probs_per_sample) / len(prm_probs_per_sample) - rewards.append(reward) - - return rewards, prm_probs - - def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: - """Prepare the inputs for inference with the model. - - Args: - user_content (str): the user query - steps (List(str)): assistant response, broken down into steps - """ - text_with_steps_marked = "" - - for step in steps: - text_with_steps_marked += f"{step} {self._score_token}" - - message = [ - {"role": "user", "content": user_content}, - {"role": "assistant", "content": text_with_steps_marked}, - ] - input_message = self.tokenizer.apply_chat_template(message, tokenize=False) - - return self.tokenizer( - [input_message], return_tensors="pt", padding=True, truncation=True - ) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 272860af..93b02fd7 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -1,13 +1,16 @@ """Utilities for dealing with tools.""" +import inspect import json -from collections.abc import Callable, Generator, Iterable, Mapping -from typing import Any +import re +from collections import defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence +from typing import Any, Literal -from ollama._utils import convert_function_to_tool +from pydantic import BaseModel, ConfigDict, Field -from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, Component, TemplateRepresentation +from ..core import CBlock, Component, TemplateRepresentation +from .model_options import ModelOption def add_tools_from_model_options( @@ -129,9 +132,8 @@ def find_func(d) -> tuple[str | None, Mapping | None]: return None, None -# NOTE: these extraction tools only work for json based outputs. def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: - """A simple parser that will scan a string for tools and attempt to extract them.""" + """A simple parser that will scan a string for tools and attempt to extract them; only works for json based outputs.""" processed = " ".join(llm_response.split()) tools = [] @@ -141,3 +143,226 @@ def parse_tools(llm_response: str) -> list[tuple[str, Mapping]]: tools.append((tool_name, tool_arguments)) return tools + + +# Below functions and classes extracted from Ollama Python SDK (v0.6.1) +# so that all backends don't need it installed. +# https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_types.py#L19-L101 +class SubscriptableBaseModel(BaseModel): + """Class imported from Ollama.""" + + def __getitem__(self, key: str) -> Any: + """Getitem. + + >>> msg = Message(role='user') + >>> msg['role'] + 'user' + >>> msg = Message(role='user') + >>> msg['nonexistent'] + Traceback (most recent call last): + KeyError: 'nonexistent' + """ + if key in self: + return getattr(self, key) + + raise KeyError(key) + + def __setitem__(self, key: str, value: Any) -> None: + """Setitem. 
+ + >>> msg = Message(role='user') + >>> msg['role'] = 'assistant' + >>> msg['role'] + 'assistant' + >>> tool_call = Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={})) + >>> msg = Message(role='user', content='hello') + >>> msg['tool_calls'] = [tool_call] + >>> msg['tool_calls'][0]['function']['name'] + 'foo' + """ + setattr(self, key, value) + + def __contains__(self, key: str) -> bool: + """Contains. + + >>> msg = Message(role='user') + >>> 'nonexistent' in msg + False + >>> 'role' in msg + True + >>> 'content' in msg + False + >>> msg.content = 'hello!' + >>> 'content' in msg + True + >>> msg = Message(role='user', content='hello!') + >>> 'content' in msg + True + >>> 'tool_calls' in msg + False + >>> msg['tool_calls'] = [] + >>> 'tool_calls' in msg + True + >>> msg['tool_calls'] = [Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))] + >>> 'tool_calls' in msg + True + >>> msg['tool_calls'] = None + >>> 'tool_calls' in msg + True + >>> tool = Tool() + >>> 'type' in tool + True + """ + if key in self.model_fields_set: + return True + + if value := self.__class__.model_fields.get(key): + return value.default is not None + + return False + + def get(self, key: str, default: Any = None) -> Any: + """Get. + + >>> msg = Message(role='user') + >>> msg.get('role') + 'user' + >>> msg = Message(role='user') + >>> msg.get('nonexistent') + >>> msg = Message(role='user') + >>> msg.get('nonexistent', 'default') + 'default' + >>> msg = Message(role='user', tool_calls=[ Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))]) + >>> msg.get('tool_calls')[0]['function']['name'] + 'foo' + """ + return getattr(self, key) if hasattr(self, key) else default + + +# https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_types.py#L337-L363 +class Tool(SubscriptableBaseModel): + """Class imported from Ollama.""" + + type: str | None = "function" + + class Function(SubscriptableBaseModel): + """Class imported from Ollama.""" + + name: str | None = None + description: str | None = None + + class Parameters(SubscriptableBaseModel): + """Class imported from Ollama.""" + + model_config = ConfigDict(populate_by_name=True) + type: Literal["object"] | None = "object" + defs: Any | None = Field(None, alias="$defs") + items: Any | None = None + required: Sequence[str] | None = None + + class Property(SubscriptableBaseModel): + """Class imported from Ollama.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + type: str | Sequence[str] | None = None + items: Any | None = None + description: str | None = None + enum: Sequence[Any] | None = None + + properties: Mapping[str, Property] | None = None + + parameters: Parameters | None = None + + function: Function | None = None + + +# https://github.com/ollama/ollama-python/blob/main/ollama/_utils.py#L13-L53 +def _parse_docstring(doc_string: str | None) -> dict[str, str]: + """Imported from Ollama.""" + parsed_docstring: defaultdict[str, str] = defaultdict(str) + if not doc_string: + return parsed_docstring + + key = str(hash(doc_string)) + for line in doc_string.splitlines(): + lowered_line = line.lower().strip() + if lowered_line.startswith("args:"): + key = "args" + elif lowered_line.startswith(("returns:", "yields:", "raises:")): + key = "_" + + else: + # maybe change to a list and join later + parsed_docstring[key] += f"{line.strip()}\n" + + last_key = None + for line in parsed_docstring["args"].splitlines(): + line = line.strip() + if ":" in 
line: + # Split the line on either: + # 1. A parenthetical expression like (integer) - captured in group 1 + # 2. A colon : + # Followed by optional whitespace. Only split on first occurrence. + parts = re.split(r"(?:\(([^)]*)\)|:)\s*", line, maxsplit=1) + + arg_name = parts[0].strip() + last_key = arg_name + + # Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon + arg_description = parts[-1].strip() + if len(parts) > 2 and parts[1]: # Has parenthetical content + arg_description = parts[-1].split(":", 1)[-1].strip() + + parsed_docstring[last_key] = arg_description + + elif last_key and line: + parsed_docstring[last_key] += " " + line + + return parsed_docstring + + +# https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_utils.py#L56-L90 +def convert_function_to_tool(func: Callable) -> Tool: + """Imported from Ollama.""" + doc_string_hash = str(hash(inspect.getdoc(func))) + parsed_docstring = _parse_docstring(inspect.getdoc(func)) + schema = type( + func.__name__, + (BaseModel,), + { + "__annotations__": { + k: v.annotation if v.annotation != inspect._empty else str + for k, v in inspect.signature(func).parameters.items() + }, + "__signature__": inspect.signature(func), + "__doc__": parsed_docstring[doc_string_hash], + }, + ).model_json_schema() # type: ignore + + for k, v in schema.get("properties", {}).items(): + # If type is missing, the default is string + types = ( + {t.get("type", "string") for t in v.get("anyOf")} + if "anyOf" in v + else {v.get("type", "string")} + ) + if "null" in types: + schema["required"].remove(k) + types.discard("null") + + schema["properties"][k] = { + "description": parsed_docstring[k], + "type": ", ".join(types), + } + + tool = Tool( + type="function", + function=Tool.Function( + name=func.__name__, + description=schema.get("description", ""), + parameters=Tool.Function.Parameters(**schema), + ), + ) + + return Tool.model_validate(tool) diff --git a/mellea/backends/_utils.py b/mellea/backends/utils.py similarity index 84% rename from mellea/backends/_utils.py rename to mellea/backends/utils.py index 28dc6d5f..7ecca2ff 100644 --- a/mellea/backends/_utils.py +++ b/mellea/backends/utils.py @@ -1,22 +1,14 @@ +"""Utilities for Backends.""" + from __future__ import annotations import inspect -import itertools from collections.abc import Callable -from typing import Any, Literal - -from mellea.backends.formatter import Formatter -from mellea.backends.tools import parse_tools -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( - CBlock, - Component, - Context, - ModelOutputThunk, - ModelToolCall, -) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement + +from ..core import CBlock, Component, Context, FancyLogger, ModelToolCall +from ..formatters import ChatFormatter +from ..stdlib.components import Message +from .tools import parse_tools # Chat = dict[Literal["role", "content"], str] # external apply_chat_template type hint is weaker # Chat = dict[str, str | list[dict[str, Any]] ] # for multi-modal models @@ -26,7 +18,7 @@ def to_chat( action: Component | CBlock, ctx: Context, - formatter: Formatter, + formatter: ChatFormatter, system_prompt: str | None, ) -> list[Chat]: """Converts a context and an action into a series of dicts to be passed to apply_chat_template . 
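# Illustrative usage sketch (not taken from this changeset): how a chat-style backend might call the
# `to_chat` helper above, assuming the signature shown in mellea/backends/utils.py. The `tokenizer`
# argument is assumed to be any Hugging Face tokenizer exposing apply_chat_template; the function
# name, prompt text, and empty ChatContext are illustrative assumptions only.
from mellea.backends.utils import to_chat
from mellea.formatters import ChatFormatter
from mellea.stdlib.components import Message
from mellea.stdlib.context import ChatContext


def build_chat_prompt(tokenizer, formatter: ChatFormatter) -> str:
    ctx = ChatContext()  # empty history; a real backend would pass the session's context
    action = Message(role="user", content="Summarize the incident report.")
    # to_chat linearizes the context plus the pending action into role/content dicts,
    # e.g. [{"role": "user", "content": "Summarize the incident report."}]
    chats = to_chat(action, ctx, formatter, None)
    return tokenizer.apply_chat_template(chats, tokenize=False)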
diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index 56d85b50..c2fac0ee 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -5,51 +5,47 @@ from __future__ import annotations -import abc import asyncio import dataclasses import datetime import functools import importlib -import inspect import json import os import shutil from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Optional, overload +from typing import Any, overload -import msgspec # type:ignore +import msgspec import outlines import outlines_core import torch -import vllm # type:ignore -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from mellea.backends import BaseModelSubclass -from mellea.backends._utils import to_chat, to_tool_calls -from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import ( - add_tools_from_context_actions, - add_tools_from_model_options, - convert_tools_to_json, -) -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue -from mellea.helpers.event_loop_helper import _run_async_in_thread -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( +import vllm +from transformers import AutoTokenizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from ..backends import ModelIdentifier +from ..core import ( + BaseModelSubclass, C, CBlock, Component, Context, + FancyLogger, GenerateLog, GenerateType, ModelOutputThunk, - TemplateRepresentation, ) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import LLMaJRequirement, Requirement +from ..formatters import ChatFormatter, TemplateFormatter +from ..helpers import get_current_event_loop, send_to_queue +from .backend import FormatterBackend +from .model_options import ModelOption +from .tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, + convert_tools_to_json, +) +from .utils import to_chat, to_tool_calls assert outlines, "outlines needs to be present to make outlines_core work" @@ -71,7 +67,7 @@ class LocalVLLMBackend(FormatterBackend): def __init__( self, model_id: str | ModelIdentifier, - formatter: Formatter | None = None, + formatter: ChatFormatter | None = None, *, model_options: dict | None = None, ): @@ -156,7 +152,7 @@ def __init__( vllm.AsyncEngineArgs(model=self._hf_model_id, **engine_args) ) break - except torch._dynamo.exc.BackendCompilerFailed as e: + except torch._dynamo.exc.BackendCompilerFailed as e: # type: ignore # example: # torch._dynamo.exc.BackendCompilerFailed: backend='' raised: # RuntimeError: vLLM failed to compile the model. The most likely reason for this is that a previous compilation failed, leading to a corrupted compilation artifact. We recommend trying to remove ~/.cache/vllm/torch_compile_cache and try again to see the real issue. 
@@ -313,10 +309,10 @@ async def _generate_from_context_standard( ), output_kind=( # returns results incrementally - vllm.sampling_params.RequestOutputKind.DELTA + vllm.sampling_params.RequestOutputKind.DELTA # type: ignore if model_options.get(ModelOption.STREAM, False) # returns only the final result - else vllm.sampling_params.RequestOutputKind.FINAL_ONLY + else vllm.sampling_params.RequestOutputKind.FINAL_ONLY # type: ignore ), ) @@ -329,7 +325,7 @@ async def _generate_from_context_standard( schema_json # type: ignore ) # type: ignore - from outlines.processors import RegexLogitsProcessor + from outlines.processors import RegexLogitsProcessor # type: ignore logits_processor = RegexLogitsProcessor( regex_str, @@ -475,7 +471,7 @@ async def generate_from_raw( **self._make_backend_specific_and_remove( model_options, vllm.SamplingParams ), - output_kind=vllm.sampling_params.RequestOutputKind.FINAL_ONLY, # returns only the final results + output_kind=vllm.sampling_params.RequestOutputKind.FINAL_ONLY, # returns only the final results # type: ignore ) if format is not None: @@ -485,7 +481,7 @@ async def generate_from_raw( schema_json # type: ignore ) # type: ignore - from outlines.processors import RegexLogitsProcessor + from outlines.processors import RegexLogitsProcessor # type: ignore logits_processor = RegexLogitsProcessor( regex_str, diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 73f257e5..58004c4c 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -14,37 +14,36 @@ from ibm_watsonx_ai.foundation_models import ModelInference from ibm_watsonx_ai.foundation_models.schema import TextChatParameters -from mellea.backends import BaseModelSubclass, model_ids -from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter -from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import ( - add_tools_from_context_actions, - add_tools_from_model_options, - convert_tools_to_json, -) -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import ( - ClientCache, - get_current_event_loop, - send_to_queue, -) -from mellea.helpers.fancy_logger import FancyLogger -from mellea.helpers.openai_compatible_helpers import ( - chat_completion_delta_merge, - extract_model_tool_requests, -) -from mellea.stdlib.base import ( +from ..backends import ModelIdentifier, model_ids +from ..core import ( + BaseModelSubclass, C, CBlock, Component, Context, + FancyLogger, GenerateLog, GenerateType, ModelOutputThunk, ModelToolCall, ) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import ALoraRequirement # type: ignore +from ..formatters import ChatFormatter, TemplateFormatter +from ..helpers import ( + ClientCache, + chat_completion_delta_merge, + extract_model_tool_requests, + get_current_event_loop, + send_to_queue, +) +from ..stdlib.components import Message +from ..stdlib.requirements import ALoraRequirement +from .backend import FormatterBackend +from .model_options import ModelOption +from .tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, + convert_tools_to_json, +) format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors @@ -55,7 +54,7 @@ class WatsonxAIBackend(FormatterBackend): def __init__( self, model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B, - formatter: Formatter | None = None, + formatter: ChatFormatter | None = None, base_url: str | None = None, 
model_options: dict | None = None, *, diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py new file mode 100644 index 00000000..0dbf9c55 --- /dev/null +++ b/mellea/core/__init__.py @@ -0,0 +1,23 @@ +"""Core Library for Mellea Interfaces.""" + +from .backend import Backend, BaseModelSubclass, generate_walk +from .base import ( + C, + CBlock, + Component, + ComponentParseError, + Context, + ContextTurn, + GenerateLog, + GenerateType, + ImageBlock, + ModelOutputThunk, + ModelToolCall, + S, + TemplateRepresentation, + blockify, +) +from .formatter import Formatter +from .requirement import Requirement, ValidationResult, default_output_to_bool +from .sampling import SamplingResult, SamplingStrategy +from .utils import FancyLogger diff --git a/mellea/core/backend.py b/mellea/core/backend.py new file mode 100644 index 00000000..8f505b00 --- /dev/null +++ b/mellea/core/backend.py @@ -0,0 +1,136 @@ +"""Interfaces for Backends and Generation.""" + +import abc +import asyncio +import itertools +from collections.abc import Sequence +from typing import overload + +import pydantic +import typing_extensions + +from .base import C, CBlock, Component, Context, ModelOutputThunk +from .utils import FancyLogger + +# Necessary to define a type that supports `None` so that the BaseModelSubclass +# can have a default value. Otherwise, Python complains about typed-components +# since types with default values must come after those without default values in +# function signatures (which is incompatible with our function parameter formatting). +pydantic_model_or_none = pydantic.BaseModel | None +BaseModelSubclass = typing_extensions.TypeVar( + "BaseModelSubclass", bound=pydantic_model_or_none, default=None +) # must be a subclass of BaseModel + + +class Backend(abc.ABC): + """An abstract `Backend`.""" + + @abc.abstractmethod + async def generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + """Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request. + + Args: + action: The last item of the context should be passed in as an `action` instead of as part of the `ctx`. See `docs/dev/generate_signature_decisions.md`. + ctx: The rest of the context. + format: A response format to be used for structured outputs / constrained decoding. + model_options: Any model options to upsert into the defaults for this call. + tool_calls: If `True`, then tool calls are extracted from the `action` `Component`. Assumption: if tool_calls is enabled, then the action `Component` has a TemplateRepresentation. + + Returns: + a tuple of (ModelOutputThunk, Context) where the Context is the new context after the generation has been completed. + """ + ... + + @overload + async def generate_from_raw( + self, + actions: list[Component[C]], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C]]: ... + + @overload + async def generate_from_raw( + self, + actions: list[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk[C | str]]: ...
+ + + @abc.abstractmethod + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + """Generates a model output from the provided input. Does not use context or templates. + + Args: + actions: list of actions to generate responses for. Each action is separate. + ctx: context passed to generation. Currently not used in generate_from_raw. + format: A response format to be used for structured outputs / constrained decoding. Note: some backends do not support this parameter. They will log warnings and continue to generate. + model_options: Any model options to upsert into the defaults for this call. + tool_calls: Always set to false unless supported by backend. + """ + + async def do_generate_walk( + self, action: CBlock | Component | ModelOutputThunk + ) -> None: + """Does the generation walk.""" + _to_compute = list(generate_walk(action)) + coroutines = [x.avalue() for x in _to_compute] + # The following log message might get noisy. Feel free to remove if so. + if len(_to_compute) > 0: + FancyLogger.get_logger().info( + f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." + ) + await asyncio.gather(*coroutines) + + async def do_generate_walks( + self, actions: list[CBlock | Component | ModelOutputThunk] + ) -> None: + """Does the generation walk.""" + _to_compute = [] + for action in actions: + _to_compute.extend(list(generate_walk(action))) + coroutines = [x.avalue() for x in _to_compute] + # The following log message might get noisy. Feel free to remove if so. + if len(_to_compute) > 0: + FancyLogger.get_logger().info( + f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." + ) + await asyncio.gather(*coroutines) + + +def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]: + """Returns the generation walk ordering for a Span.""" + match c: + case ModelOutputThunk() if not c.is_computed(): + return [c] + case CBlock(): + return [] + case Component(): + parts_walk = [generate_walk(p) for p in c.parts()] + return list(itertools.chain.from_iterable(parts_walk)) # aka flatten + case _: + raise ValueError( + f"parts should only contain CBlocks, Components, or ModelOutputThunks; found `{c!s:.10}{'...' if len(str(c)) > 10 else ''}` (type: {type(c)})" + ) diff --git a/mellea/stdlib/base.py b/mellea/core/base.py similarity index 84% rename from mellea/stdlib/base.py rename to mellea/core/base.py index 9c375e99..41179bfe 100644 --- a/mellea/stdlib/base.py +++ b/mellea/core/base.py @@ -1,4 +1,4 @@ -"""Basic stdlib data structures.""" +"""Core Classes and Data Structures.""" from __future__ import annotations @@ -17,19 +17,7 @@ import typing_extensions from PIL import Image as PILImage -S = typing_extensions.TypeVar("S", default=Any, covariant=True) -"""Used for class definitions for Component and ModelOutputThunk; also used for functions that don't accept CBlocks. Defaults to `Any`.""" - -C = typing_extensions.TypeVar("C", default=str) -"""Used for component typing in function parameters where the function takes a Component[C] and/or CBlock and can return a ModelOutputThunk[C].
Defaults to `str`.""" - -class ComponentParseError(Exception): - """Raised by `Component.parse()` when the underlying parsing method throws an exception.""" - - -# For ModelOutputThunk return types to be typed correctly, CBlocks must be defined -# using generics and a type var that defaults to str. CBlocks should never be initialized with [type]. class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" @@ -138,6 +126,17 @@ def __repr__(self): return f"ImageBlock({self._value}, {self._meta.__repr__()})" +S = typing_extensions.TypeVar("S", default=Any, covariant=True) +"""Used for class definitions for Component and ModelOutputThunk; also used for functions that don't accept CBlocks. Defaults to `Any`.""" + +C = typing_extensions.TypeVar("C", default=str) +"""Used for component typing in function parameters where the function takes a Component[C] and/or CBlock and can return a ModelOutputThunk[C]. Defaults to `str`.""" + + +class ComponentParseError(Exception): + """Raised by `Component.parse()` when the underlying parsing method throws an exception.""" + + @runtime_checkable class Component(Protocol, Generic[S]): """A `Component` is a composite data structure that is intended to be represented to an LLM.""" @@ -168,58 +167,6 @@ def _parse(self, computed: ModelOutputThunk) -> S: raise NotImplementedError("parse isn't implemented by default") -def get_images_from_component(c: Component) -> None | list[ImageBlock]: - """Gets images from a `Component` if they are present and a non-empty list, otherwise returns None.""" - if hasattr(c, "images"): - imgs = c.images # type: ignore - if imgs is not None: - assert isinstance(imgs, list), "images field must be a list." - assert all(isinstance(im, ImageBlock) for im in imgs), ( - "all elements of images list must be ImageBlocks." - ) - if len(imgs) == 0: - return None - else: - return imgs - else: - return None - else: - return None - - -# TODO: Add support for passing in docs as model options. -class Document(Component[str]): - """Documents should typically be used in a Message object.""" - - def __init__(self, text: str, title: str | None = None, doc_id: str | None = None): - """Create a document object. Should typically be used as a list in the `_docs` field of Message.""" - self.text = text - self.title = title - self.doc_id = doc_id - - def parts(self) -> list[Component | CBlock]: - """The set of all the constituent parts of the `Component`.""" - raise NotImplementedError("parts isn't implemented by default") - - def format_for_llm(self) -> str: - """Formats the `Document` into a string. - - Returns: a string - """ - doc = "" - if self.doc_id is not None: - doc += f"document ID '{self.doc_id}': " - if self.title is not None: - doc += f"'{self.title}': " - doc += f"{self.text}" - - return doc - - def _parse(self, computed: ModelOutputThunk) -> str: - """Parse the model output. 
Returns string value for now.""" - return computed.value if computed.value is not None else "" - - class GenerateType(enum.Enum): """Used to track what functions can be used to extract a value from a ModelOutputThunk.""" @@ -461,20 +408,6 @@ def __deepcopy__(self, memo): return deepcopied -def blockify(s: str | CBlock | Component) -> CBlock | Component: - """`blockify` is a helper function that turns raw strings into CBlocks.""" - # noinspection PyUnreachableCode - match s: - case str(): - return CBlock(s) - case CBlock(): - return s - case Component(): - return s - case _: - raise Exception("Type Error") - - @dataclass class ContextTurn: """A turn of model input and model output.""" @@ -634,37 +567,6 @@ def view_for_generation(self) -> list[Component | CBlock] | None: ... -class ChatContext(Context): - """Initializes a chat context with unbounded window_size and is_chat=True by default.""" - - def __init__(self, *, window_size: int | None = None): - """Constructs a new chat context.""" - super().__init__() - self._window_size = window_size - - def add(self, c: Component | CBlock) -> ChatContext: - """Add a new component/cblock to the context. Returns the new context.""" - new = ChatContext.from_previous(self, c) - new._window_size = self._window_size - return new - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Returns the context in a linearized form. Uses the window_size set during initialization.""" - return self.as_list(self._window_size) - - -class SimpleContext(Context): - """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" - - def add(self, c: Component | CBlock) -> SimpleContext: - """Add a new component/cblock to the context. 
Returns the new context.""" - return SimpleContext.from_previous(self, c) - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Returns an empty list.""" - return [] - - @dataclass class TemplateRepresentation: """Representing a component as a set of important attributes that can be consumed by the formatter.""" @@ -718,55 +620,34 @@ def call_func(self) -> Any: return self.func(**self.args) -class SimpleComponent(Component[str]): - """A Component that is make up of named spans.""" - - def __init__(self, **kwargs): - """Initialized a simple component of the constructor's kwargs.""" - for key in kwargs.keys(): - if type(kwargs[key]) is str: - kwargs[key] = CBlock(value=kwargs[key]) - self._kwargs_type_check(kwargs) - self._kwargs = kwargs - - def parts(self): - """Returns the values of the kwargs.""" - return list(self._kwargs.values()) - - def _kwargs_type_check(self, kwargs): - for key in kwargs.keys(): - value = kwargs[key] - assert issubclass(type(value), Component) or issubclass( - type(value), CBlock - ), f"Expected span but found {type(value)} of value: {value}" - assert type(key) is str - return True - - @staticmethod - def make_simple_string(kwargs): - """Uses <|key|>value to represent a simple component.""" - return "\n".join( - [f"<|{key}|>{value}" for (key, value) in kwargs.items()] - ) - - @staticmethod - def make_json_string(kwargs): - """Uses json.""" - str_args = dict() - for key in kwargs.keys(): - match kwargs[key]: - case ModelOutputThunk() | CBlock(): - str_args[key] = kwargs[key].value - case Component(): - str_args[key] = kwargs[key].format_for_llm() - import json - - return json.dumps(str_args) +def blockify(s: str | CBlock | Component) -> CBlock | Component: + """`blockify` is a helper function that turns raw strings into CBlocks.""" + # noinspection PyUnreachableCode + match s: + case str(): + return CBlock(s) + case CBlock(): + return s + case Component(): + return s + case _: + raise Exception("Type Error") - def format_for_llm(self): - """Uses a string rep.""" - return SimpleComponent.make_json_string(self._kwargs) - def _parse(self, computed: ModelOutputThunk) -> str: - """Parse the model output. Returns string value for now.""" - return computed.value if computed.value is not None else "" +def get_images_from_component(c: Component) -> None | list[ImageBlock]: + """Gets images from a `Component` if they are present and a non-empty list, otherwise returns None.""" + if hasattr(c, "images"): + imgs = c.images # type: ignore + if imgs is not None: + assert isinstance(imgs, list), "images field must be a list." + assert all(isinstance(im, ImageBlock) for im in imgs), ( + "all elements of images list must be ImageBlocks." + ) + if len(imgs) == 0: + return None + else: + return imgs + else: + return None + else: + return None diff --git a/mellea/core/formatter.py b/mellea/core/formatter.py new file mode 100644 index 00000000..bed99a05 --- /dev/null +++ b/mellea/core/formatter.py @@ -0,0 +1,14 @@ +"""Interfaces for Formatters.""" + +import abc + +from .base import CBlock, Component + + +class Formatter(abc.ABC): + """A Formatter converts `Component`s into strings and parses `ModelOutputThunk`s into `Component`s (or `CBlock`s).""" + + @abc.abstractmethod + def print(self, c: Component | CBlock) -> str: + """Renders a component for input to a model.""" + ... 
diff --git a/mellea/core/requirement.py b/mellea/core/requirement.py new file mode 100644 index 00000000..9162c1fc --- /dev/null +++ b/mellea/core/requirement.py @@ -0,0 +1,170 @@ +"""Interface for Requirements.""" + +import re +from collections.abc import Callable +from copy import copy + +from .backend import Backend, BaseModelSubclass +from .base import CBlock, Component, Context, ModelOutputThunk, TemplateRepresentation + + +class ValidationResult: + """ValidationResults store the output of a Requirement's validation. They can be used to return additional info from validation functions, which is useful for sampling/repairing.""" + + def __init__( + self, + result: bool, + *, + reason: str | None = None, + score: float | None = None, + thunk: ModelOutputThunk | None = None, + context: Context | None = None, + ): + """The result of a requirement's validation. + + A ValidationResult's result field always contains a definitive pass/fail. The other fields can be used to communicate additional information about that result. + + Args: + result: a boolean that is true if the requirement passed + reason: a reason for the result + score: if your validator gives you a score back, you can add this as metadata + thunk: if your validator utilizes a backend to generate a response, the ModelOutputThunk returned from that request + context: if your validator utilizes a backend to generate a response, the context associated with that response + """ + self._result = result + self._reason = reason + self._score = score + self._thunk = thunk + self._context = context + + @property + def reason(self) -> str | None: + """Reason for the validation result.""" + return self._reason + + @property + def score(self) -> float | None: + """An optional score for the validation result.""" + return self._score + + @property + def thunk(self) -> ModelOutputThunk | None: + """The ModelOutputThunk associated with the validation func if an llm was used to generate the final result.""" + return self._thunk + + @property + def context(self) -> Context | None: + """The context associated with validation if a backend was used to generate the final result.""" + return self._context + + def as_bool(self) -> bool: + """Return a boolean value based on the result.""" + return self._result + + def __bool__(self) -> bool: + """Return a boolean value based on the result.""" + return self.as_bool() + + +def default_output_to_bool(x: CBlock | str) -> bool: + """Checks whether a given output should be converted to `True`. + + Checks if the output is exactly equal to "yes" or "y" (case-insensitive). If not, it will also + check if any of the words in the output are "yes" (case-insensitive). + """ + output = str(x) + + if output.upper() == "YES" or output.upper() == "Y": + return True + + word_splits = re.split(r"\W+", output) + if "YES" in [word.upper() for word in word_splits]: + return True + + return False + + +class Requirement(Component[str]): + """Requirements are a special type of Component used as input to the Validate step in Instruct/Validate/Repair patterns.""" + + def __init__( + self, + description: str | None = None, + validation_fn: Callable[[Context], ValidationResult] | None = None, + *, + output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool, + check_only: bool = False, + ): + """A Requirement, interpreted over a Context. + + By default, requirements are validated by the model using LLM-as-a-Judge (or a `constraint` LoRA when available).
However, you can also provide a `validate` function with arbitrary behavior. + + Args: + description: A natural-language description of the requirement. This will sometimes be included in `Instruction` prompts; if you do not want the requirement to be included in the prompt to avoid [Purple Elephant Effects](https://${PROJECT_URL}/llm-requirement-engineering-and-purple-elephants/) use check_only=True. + validation_fn: If provided, this function will be executed instead of using LLM-as-a-Judge. The `bool()` for the function's output defines whether the requirement passes. + output_to_bool: An `output_to_bool` may be provided so that the library can translate the LLM-as-a-judge or ALora output into a boolean value. If none is provided, we will look for 'yes' (case-insensitive) in the LLMaJ output. + check_only: If set, then `Instruction` will not include this requirement in its prompt. + """ + self.description = description + self.output_to_bool = output_to_bool + self.validation_fn = validation_fn + self.check_only = check_only + + # Used for validation. Do not manually populate. + self._output: str | None = None + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + """Chooses the appropriate validation strategy and applies that strategy.""" + if self.validation_fn is not None: + # Python validation strategy + return self.validation_fn(ctx) + else: + # LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch. + assert self.output_to_bool is not None + last_output = ctx.last_output() + assert isinstance(last_output, ModelOutputThunk), ( + " Context has no appropriate last output" + ) + + # Create a copy of the requirement that holds the output + # and its template gets populated with the output correctly. + req_copy = copy(self) + req_copy._output = last_output.value + llm_as_a_judge_result, val_ctx = await backend.generate_from_context( + req_copy, ctx, format=format, model_options=model_options + ) + await llm_as_a_judge_result.avalue() + + return ValidationResult( + result=self.output_to_bool(llm_as_a_judge_result), + reason=llm_as_a_judge_result.value, + thunk=llm_as_a_judge_result, + context=val_ctx, + ) + + def parts(self): + """Returns all of the constituent parts of a Requirement.""" + return [] + + def format_for_llm(self) -> TemplateRepresentation | str: + """Some object protocol magic happens here with management of the output.""" + assert self._output is not None, ( + "Object protocol error: should never try to templatize a Requirement except inside of a validate call for that same requirement." + ) + return TemplateRepresentation( + obj=self, + args={"description": self.description, "output": self._output}, + tools=None, + template_order=["*", "Requirement"], + ) + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. 
Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/sampling/types.py b/mellea/core/sampling.py similarity index 94% rename from mellea/stdlib/sampling/types.py rename to mellea/core/sampling.py index 69cd13c2..41b945e0 100644 --- a/mellea/stdlib/sampling/types.py +++ b/mellea/core/sampling.py @@ -1,11 +1,11 @@ -"""Base types for sampling.""" +"""Interfaces for Sampling Strategies.""" import abc -from typing import Generic, TypeVar +from typing import Generic -from mellea.backends import Backend, BaseModelSubclass -from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk, S -from mellea.stdlib.requirement import Requirement, ValidationResult +from .backend import Backend, BaseModelSubclass +from .base import CBlock, Component, Context, ModelOutputThunk, S +from .requirement import Requirement, ValidationResult class SamplingResult(CBlock, Generic[S]): diff --git a/mellea/helpers/fancy_logger.py b/mellea/core/utils.py similarity index 98% rename from mellea/helpers/fancy_logger.py rename to mellea/core/utils.py index d92bc39e..edb16d68 100644 --- a/mellea/helpers/fancy_logger.py +++ b/mellea/core/utils.py @@ -1,4 +1,4 @@ -"""Hendrik's Fancy Logger.""" +"""Utils for Core Library.""" import json import logging @@ -38,7 +38,7 @@ def emit(self, record): class JsonFormatter(logging.Formatter): """Logging formatter for JSON.""" - def format(self, record): + def format(self, record): # type: ignore """Formats record as a JSON serializable object.""" log_record = { "timestamp": self.formatTime(record, self.datefmt), diff --git a/mellea/formatters/__init__.py b/mellea/formatters/__init__.py new file mode 100644 index 00000000..6fd5e6bb --- /dev/null +++ b/mellea/formatters/__init__.py @@ -0,0 +1,6 @@ +"""Formatters.""" + +# Import from core for ergonomics. +from ..core import Formatter +from .chat_formatter import ChatFormatter +from .template_formatter import TemplateFormatter diff --git a/mellea/formatters/chat_formatter.py b/mellea/formatters/chat_formatter.py new file mode 100644 index 00000000..4f084881 --- /dev/null +++ b/mellea/formatters/chat_formatter.py @@ -0,0 +1,56 @@ +"""ChatFormatter.""" + +from ..core import ( + CBlock, + Component, + Formatter, + ModelOutputThunk, + TemplateRepresentation, +) +from ..stdlib.components.chat import Message + + +class ChatFormatter(Formatter): + """Formatter used by Legacy backends to format Contexts as Messages.""" + + def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: + """Helper method that converts a linearized chat history into a list of messages. The purpose of this helper is to prepare a sequence of Messages for input to a chat endpoint.""" + + def _to_msg(c: Component | CBlock) -> Message: + role: Message.Role = "user" # default to `user`; see ModelOutputThunk below for when the role changes. + + # Check if it's a ModelOutputThunk first since that changes what we should be printing + # as the message content. + if isinstance(c, ModelOutputThunk): + role = "assistant" # ModelOutputThunks should always be responses from a model. + + assert c.is_computed() + assert ( + c.value is not None + ) # This is already entailed by c.is_computed(); the line is included here to satisfy the type-checker. + + if c.parsed_repr is not None: + if isinstance(c.parsed_repr, Component): + # Only use the parsed_repr if it's something that we know how to print. + c = c.parsed_repr # This might be a message. + else: + # Otherwise, explicitly stringify it. 
+ c = Message(role=role, content=str(c.parsed_repr)) + else: + c = Message(role=role, content=c.value) # type: ignore + + match c: + case Message(): + return c + case Component(): + images = None + tr = c.format_for_llm() + if isinstance(tr, TemplateRepresentation): + images = tr.images + + # components can have images + return Message(role=role, content=self.print(c), images=images) + case _: + return Message(role=role, content=self.print(c)) + + return [_to_msg(c) for c in cs] diff --git a/mellea/backends/formatter.py b/mellea/formatters/template_formatter.py similarity index 70% rename from mellea/backends/formatter.py rename to mellea/formatters/template_formatter.py index 0a3bcb35..c15e7ed7 100644 --- a/mellea/backends/formatter.py +++ b/mellea/formatters/template_formatter.py @@ -1,6 +1,5 @@ -"""Abstract interfaces for Formatters.""" +"""Template Formatter.""" -import abc import os import re import sys @@ -10,71 +9,13 @@ import jinja2 -from mellea.backends import Backend -from mellea.backends.cache import SimpleLRUCache -from mellea.backends.model_ids import ModelIdentifier -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) -from mellea.stdlib.chat import Message +from ..backends.cache import SimpleLRUCache +from ..backends.model_ids import ModelIdentifier +from ..core import CBlock, Component, FancyLogger, TemplateRepresentation +from .chat_formatter import ChatFormatter -class Formatter(abc.ABC): - """A Formatter converts `Component`s into strings and parses `ModelOutputThunk`s into `Component`s (or `CBlock`s).""" - - @abc.abstractmethod - def print(self, c: Component | CBlock) -> str: - """Renders a component for input to a model.""" - ... - - def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: - """Helper method that converts a linearized chat history into a list of messages. The purpose of this helper is to prepare a sequence of Messages for input to a chat endpoint.""" - - def _to_msg(c: Component | CBlock) -> Message: - role: Message.Role = "user" # default to `user`; see ModelOutputThunk below for when the role changes. - - # Check if it's a ModelOutputThunk first since that changes what we should be printing - # as the message content. - if isinstance(c, ModelOutputThunk): - role = "assistant" # ModelOutputThunks should always be responses from a model. - - assert c.is_computed() - assert ( - c.value is not None - ) # This is already entailed by c.is_computed(); the line is included here to satisfy the type-checker. - - if c.parsed_repr is not None: - if isinstance(c.parsed_repr, Component): - # Only use the parsed_repr if it's something that we know how to print. - c = c.parsed_repr # This might be a message. - else: - # Otherwise, explicitly stringify it. 
- c = Message(role=role, content=str(c.parsed_repr)) - else: - c = Message(role=role, content=c.value) # type: ignore - - match c: - case Message(): - return c - case Component(): - images = None - tr = c.format_for_llm() - if isinstance(tr, TemplateRepresentation): - images = tr.images - - # components can have images - return Message(role=role, content=self.print(c), images=images) - case _: - return Message(role=role, content=self.print(c)) - - return [_to_msg(c) for c in cs] - - -class TemplateFormatter(Formatter, abc.ABC): +class TemplateFormatter(ChatFormatter): """Formatter that uses jinja2 templates.""" def __init__( @@ -342,34 +283,3 @@ def _get_package_name(module: str) -> str: package = "" return package - - -class FormatterBackend(Backend, abc.ABC): - """`FormatterBackend`s support legacy model types. - - The `mellea` library was designed to support generative computing with [spanned attention](https://generative.computing/what-are-spans.html) over [generative programming primitives](https://generative.computing/what-are-generative-programs.html). - In the ideal world, context management is handled via span scope-relations and all generative programming primitives are baked into the model via fine-tuning. - I.e., the model's instruction tuning is done in terms of generative programming primitives, and the model is then prompted with the same set of templates that were used for that tuning. - - Today, most models do not yet support spans and even those that do are not properly tuned to leverage generative programming primitives. - The `mellea` library supports these legacy models primarily through prompt engineering surfaced via `FormatterBackends`. - A `FormatterBackend` is a backend that uses hand-engineered prompts for rendering generative programming primitives to a model and parsing responses from the model back into `mellea`. - By default, a `FormatterBackend` uses jinja2 templates for pretty-printing, and relies on the user's ad-hoc logic for parsing. - """ - - def __init__( - self, - model_id: str | ModelIdentifier, - formatter: Formatter, - *, - model_options: dict | None = None, - ): - """Initializes a formatter-based backend for `model_id`. - - Args: - model_id (str): The model_id to use. - formatter (Formatter): The formatter to use for converting components into (fragments of) prompts. - model_options (Optional[dict]): The model options to use; if None, sensible defaults will be provided. 
- """ - super().__init__(model_id, model_options=model_options) - self.formatter: Formatter = formatter diff --git a/mellea/helpers/__init__.py b/mellea/helpers/__init__.py index 24c51755..d68986e6 100644 --- a/mellea/helpers/__init__.py +++ b/mellea/helpers/__init__.py @@ -1 +1,16 @@ -"""Various helpers and utilities.""" +"""Various Helpers and Utilities.""" + +from .async_helpers import ( + ClientCache, + get_current_event_loop, + send_to_queue, + wait_for_all_mots, +) +from .event_loop_helper import _run_async_in_thread +from .openai_compatible_helpers import ( + chat_completion_delta_merge, + extract_model_tool_requests, + message_to_openai_message, + messages_to_docs, +) +from .server_type import _server_type, _ServerType diff --git a/mellea/helpers/async_helpers.py b/mellea/helpers/async_helpers.py index 9a158489..c4ee2f4b 100644 --- a/mellea/helpers/async_helpers.py +++ b/mellea/helpers/async_helpers.py @@ -3,9 +3,9 @@ import asyncio from collections import OrderedDict from collections.abc import AsyncIterator, Coroutine -from typing import Any, TypeVar +from typing import Any -from mellea.stdlib.base import ModelOutputThunk +from ..core import ModelOutputThunk async def send_to_queue( diff --git a/mellea/helpers/event_loop_helper.py b/mellea/helpers/event_loop_helper.py index 2271b601..749be4d0 100644 --- a/mellea/helpers/event_loop_helper.py +++ b/mellea/helpers/event_loop_helper.py @@ -5,7 +5,7 @@ from collections.abc import Coroutine from typing import Any, TypeVar -from mellea.helpers.async_helpers import get_current_event_loop +from .async_helpers import get_current_event_loop R = TypeVar("R") @@ -19,7 +19,7 @@ def __init__(self): Do not instantiate this class. Rely on the exported `_run_async_in_thread` function. """ self._event_loop = asyncio.new_event_loop() - self._thread: threading.Thread = threading.Thread( + self._thread: threading.Thread = threading.Thread( # type: ignore[annotation-unchecked] target=self._event_loop.run_forever, daemon=True, # type: ignore ) diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 38befda4..4462c1e8 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -4,8 +4,8 @@ from collections.abc import Callable from typing import Any -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ModelToolCall +from ..core import FancyLogger, ModelToolCall +from ..stdlib.components import Document, Message def extract_model_tool_requests( @@ -115,3 +115,53 @@ def chat_completion_delta_merge( current_tool["function"]["arguments"] += fx_info["arguments"] return merged + + +def message_to_openai_message(msg: Message): + """Serializes a mellea Message object to the message format required by OpenAI compatible api providers.""" + if msg.images is not None: + img_list = [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}} + for img in msg.images + ] + + return { + "role": msg.role, + "content": [{"type": "text", "text": msg.content}, *img_list], + } + else: + return {"role": msg.role, "content": msg.content} + # Target format: + # { + # "role": "user", + # "content": [ + # { + # "type": "text", + # "text": "What's in this picture?" 
+ # }, + # { + # "type": "image_url", + # "image_url": { + # "url": "data:image/jpeg;base64," + # } + # } + # ] + # } + + +def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]: + """Extracts the docs from a list of messages.""" + docs: list[Document] = [] + for message in msgs: + if message._docs is not None: + docs.extend(message._docs) + + json_docs: list[dict[str, str]] = [] + for doc in docs: + json_doc = {"text": doc.text} + if doc.title is not None: + json_doc["title"] = doc.title + if doc.doc_id is not None: + json_doc["doc_id"] = doc.doc_id + json_docs.append(json_doc) + return json_docs diff --git a/mellea/helpers/server_type.py b/mellea/helpers/server_type.py new file mode 100644 index 00000000..46fe0ab9 --- /dev/null +++ b/mellea/helpers/server_type.py @@ -0,0 +1,28 @@ +"""Server Type Helpers.""" + +from enum import Enum +from urllib.parse import urlparse + + +class _ServerType(Enum): + """Different types of servers that might be relevant for a backend.""" + + UNKNOWN = 0 + LOCALHOST = 1 + OPENAI = 2 + REMOTE_VLLM = 3 + """Must be set manually for now.""" + + +def _server_type(url: str) -> _ServerType: + """Find a server type based on the url.""" + try: + parsed = urlparse(url) + hostname = parsed.hostname + if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): + return _ServerType.LOCALHOST + elif hostname == "api.openai.com": + return _ServerType.OPENAI + except Exception as e: + print(f"Error parsing URL: {e}") + return _ServerType.UNKNOWN diff --git a/mellea/py.typed b/mellea/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py deleted file mode 100644 index 35fce96f..00000000 --- a/mellea/stdlib/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""The mellea standard library.""" diff --git a/mellea/stdlib/components/__init__.py b/mellea/stdlib/components/__init__.py new file mode 100644 index 00000000..0c958d57 --- /dev/null +++ b/mellea/stdlib/components/__init__.py @@ -0,0 +1,19 @@ +"""Module for Components.""" + +# Import from core for ergonomics. 
+from ...core import ( + CBlock, + Component, + ComponentParseError, + ImageBlock, + ModelOutputThunk, + TemplateRepresentation, + blockify, +) +from .chat import Message, ToolMessage, as_chat_history +from .docs.document import Document +from .instruction import Instruction +from .intrinsic import Intrinsic +from .mify import mify +from .mobject import MObject, MObjectProtocol, Query, Transform +from .simple import SimpleComponent diff --git a/mellea/stdlib/chat.py b/mellea/stdlib/components/chat.py similarity index 97% rename from mellea/stdlib/chat.py rename to mellea/stdlib/components/chat.py index f024089e..8763a70b 100644 --- a/mellea/stdlib/chat.py +++ b/mellea/stdlib/components/chat.py @@ -3,17 +3,16 @@ from collections.abc import Mapping from typing import Any, Literal -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( +from ...core import ( CBlock, Component, Context, - Document, ImageBlock, ModelOutputThunk, ModelToolCall, TemplateRepresentation, ) +from .docs.document import Document class Message(Component["Message"]): @@ -58,9 +57,9 @@ def images(self) -> None | list[str]: return [str(i.value) for i in self._images_cblocks] return None - def parts(self): + def parts(self) -> list[Component | CBlock]: """Returns all of the constituent parts of an Instruction.""" - parts = [self._content_cblock] + parts: list[Component | CBlock] = [self._content_cblock] if self._docs is not None: parts.extend(self._docs) # TODO: we need to do this but images are not currently cblocks. This is captured in an issue on Jan 26 sprint. Leaving this code commented out for now. diff --git a/mellea/stdlib/components/docs/__init__.py b/mellea/stdlib/components/docs/__init__.py new file mode 100644 index 00000000..1c63bfb2 --- /dev/null +++ b/mellea/stdlib/components/docs/__init__.py @@ -0,0 +1,3 @@ +"""Classes and functions for working with document-like objects.""" + +from .richdocument import RichDocument, Table, TableQuery, TableTransform diff --git a/mellea/stdlib/components/docs/document.py b/mellea/stdlib/components/docs/document.py new file mode 100644 index 00000000..577a6639 --- /dev/null +++ b/mellea/stdlib/components/docs/document.py @@ -0,0 +1,36 @@ +"""Document component.""" + +from ....core import CBlock, Component, ModelOutputThunk + + +# TODO: Add support for passing in docs as model options. +class Document(Component[str]): + """Documents should typically be used in a Message object.""" + + def __init__(self, text: str, title: str | None = None, doc_id: str | None = None): + """Create a document object. Should typically be used as a list in the `_docs` field of Message.""" + self.text = text + self.title = title + self.doc_id = doc_id + + def parts(self) -> list[Component | CBlock]: + """The set of all the constituent parts of the `Component`.""" + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> str: + """Formats the `Document` into a string. + + Returns: a string + """ + doc = "" + if self.doc_id is not None: + doc += f"document ID '{self.doc_id}': " + if self.title is not None: + doc += f"'{self.title}': " + doc += f"{self.text}" + + return doc + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output. 
Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/docs/richdocument.py b/mellea/stdlib/components/docs/richdocument.py similarity index 94% rename from mellea/stdlib/docs/richdocument.py rename to mellea/stdlib/components/docs/richdocument.py index 6913d71f..75bcb60c 100644 --- a/mellea/stdlib/docs/richdocument.py +++ b/mellea/stdlib/components/docs/richdocument.py @@ -11,13 +11,8 @@ from docling_core.types.doc.document import DoclingDocument, TableItem from docling_core.types.io import DocumentStream -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) -from mellea.stdlib.mobject import MObject, Query, Transform +from ....core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from ..mobject import MObject, Query, Transform class RichDocument(Component[str]): @@ -105,9 +100,10 @@ def __init__(self, obj: Table, query: str) -> None: """ super().__init__(obj, query) - def parts(self): + def parts(self) -> list[Component | CBlock]: """The list of cblocks/components on which TableQuery depends.""" - return [self._obj] + cs: list[Component | CBlock] = [self._obj] + return cs def format_for_llm(self) -> TemplateRepresentation: """Template arguments for Formatter.""" @@ -135,9 +131,10 @@ def __init__(self, obj: Table, transformation: str) -> None: """ super().__init__(obj, transformation) - def parts(self): + def parts(self) -> list[Component | CBlock]: """The parts for this component.""" - return [self._obj] + cs: list[Component | CBlock] = [self._obj] + return cs def format_for_llm(self) -> TemplateRepresentation: """Template arguments for Formatter.""" diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/components/genslot.py similarity index 98% rename from mellea/stdlib/genslot.py rename to mellea/stdlib/components/genslot.py index f56379ae..eff9ae75 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/components/genslot.py @@ -4,25 +4,28 @@ import functools import inspect from collections.abc import Awaitable, Callable, Coroutine -from copy import copy, deepcopy +from copy import deepcopy from dataclasses import dataclass, fields from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints, overload from pydantic import BaseModel, Field, create_model import mellea.stdlib.functional as mfuncs -from mellea.backends import Backend -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( + +from ...core import ( + Backend, CBlock, Component, Context, + FancyLogger, ModelOutputThunk, + Requirement, + SamplingStrategy, TemplateRepresentation, + ValidationResult, ) -from mellea.stdlib.requirement import Requirement, ValidationResult, reqify -from mellea.stdlib.sampling.types import SamplingStrategy -from mellea.stdlib.session import MelleaSession +from ..requirements.requirement import reqify +from ..session import MelleaSession P = ParamSpec("P") R = TypeVar("R") @@ -382,8 +385,8 @@ def _context_backend_extract_args_and_kwargs( return extracted - def parts(self): - """Not implemented.""" + def parts(self) -> list[Component | CBlock]: + """Parts of Genslot.""" cs: list = [] if self._arguments is not None: cs.append(self._arguments) diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/components/instruction.py similarity index 97% rename from mellea/stdlib/instruction.py rename to mellea/stdlib/components/instruction.py index 7e9440ef..32a8a0dc 100644 --- a/mellea/stdlib/instruction.py +++ 
b/mellea/stdlib/components/instruction.py @@ -6,15 +6,16 @@ import jinja2 -from mellea.stdlib.base import ( +from ...core import ( CBlock, Component, ImageBlock, ModelOutputThunk, + Requirement, TemplateRepresentation, blockify, ) -from mellea.stdlib.requirement import Requirement, reqify +from ..requirements.requirement import reqify class Instruction(Component[str]): @@ -136,8 +137,9 @@ def parts(self): cs.extend(list(self._grounding_context.values())) cs.extend(self._requirements) cs.extend(self._icl_examples) - cs = list(filter(lambda x: x is not None, cs)) - return cs + + filtered: list[Component | CBlock] = list(filter(lambda x: x is not None, cs)) # type: ignore + return filtered def format_for_llm(self) -> TemplateRepresentation: """Formats the instruction for Formatter use.""" diff --git a/mellea/stdlib/components/intrinsic/__init__.py b/mellea/stdlib/components/intrinsic/__init__.py new file mode 100644 index 00000000..1700ceb7 --- /dev/null +++ b/mellea/stdlib/components/intrinsic/__init__.py @@ -0,0 +1,3 @@ +"""Module for working with intrinsics.""" + +from .intrinsic import Intrinsic diff --git a/mellea/stdlib/intrinsics/intrinsic.py b/mellea/stdlib/components/intrinsic/intrinsic.py similarity index 90% rename from mellea/stdlib/intrinsics/intrinsic.py rename to mellea/stdlib/components/intrinsic/intrinsic.py index 925dba6b..c12fa54f 100644 --- a/mellea/stdlib/intrinsics/intrinsic.py +++ b/mellea/stdlib/components/intrinsic/intrinsic.py @@ -1,16 +1,7 @@ """Module for Intrinsics.""" -import pathlib -from copy import copy -from typing import cast - -from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) +from ....backends.adapters import AdapterType, fetch_intrinsic_metadata +from ....core import CBlock, Component, ModelOutputThunk, TemplateRepresentation class Intrinsic(Component[str]): diff --git a/mellea/stdlib/intrinsics/rag.py b/mellea/stdlib/components/intrinsic/rag.py similarity index 96% rename from mellea/stdlib/intrinsics/rag.py rename to mellea/stdlib/components/intrinsic/rag.py index 80f7a999..3d2c1ebd 100644 --- a/mellea/stdlib/intrinsics/rag.py +++ b/mellea/stdlib/components/intrinsic/rag.py @@ -3,16 +3,13 @@ import collections.abc import json -import mellea.stdlib.functional as mfuncs -from mellea.backends.adapters.adapter import ( - AdapterMixin, - AdapterType, - GraniteCommonAdapter, -) -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics.intrinsic import Intrinsic +from ....backends import ModelOption +from ....backends.adapters import AdapterMixin, AdapterType, GraniteCommonAdapter +from ....stdlib import functional as mfuncs +from ...components import Document +from ...context import ChatContext +from ..chat import Message +from .intrinsic import Intrinsic _ANSWER_RELEVANCE_CORRECTION_METHODS = { "Excessive unnecessary information": "removing the excessive information from the " diff --git a/mellea/stdlib/mify.py b/mellea/stdlib/components/mify.py similarity index 99% rename from mellea/stdlib/mify.py rename to mellea/stdlib/components/mify.py index 7cc69901..8bbfcb35 100644 --- a/mellea/stdlib/mify.py +++ b/mellea/stdlib/components/mify.py @@ -5,14 +5,14 @@ from collections.abc import Callable from typing import Any, Protocol, TypeVar, overload, runtime_checkable -from mellea.stdlib.base import ( +from ...core 
import ( CBlock, Component, ComponentParseError, ModelOutputThunk, TemplateRepresentation, ) -from mellea.stdlib.mobject import MObjectProtocol, Query, Transform +from .mobject import MObjectProtocol, Query, Transform @runtime_checkable @@ -313,7 +313,7 @@ def mify(*args, **kwargs): # noqa: D417 def _mify( *, - obj: T + obj: T # type: ignore | None = None, # Necessary if the decorator is called without args or directly on the class. query_type: type = Query, transform_type: type = Transform, diff --git a/mellea/stdlib/mobject.py b/mellea/stdlib/components/mobject.py similarity index 98% rename from mellea/stdlib/mobject.py rename to mellea/stdlib/components/mobject.py index 1fa2e83b..3ab48c04 100644 --- a/mellea/stdlib/mobject.py +++ b/mellea/stdlib/components/mobject.py @@ -6,12 +6,7 @@ from collections.abc import Callable from typing import Protocol, runtime_checkable -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) +from ...core import CBlock, Component, ModelOutputThunk, TemplateRepresentation class Query(Component[str]): diff --git a/mellea/stdlib/components/simple.py b/mellea/stdlib/components/simple.py new file mode 100644 index 00000000..2d0f7dcc --- /dev/null +++ b/mellea/stdlib/components/simple.py @@ -0,0 +1,57 @@ +"""SimpleComponent.""" + +from ...core import CBlock, Component, ModelOutputThunk + + +class SimpleComponent(Component[str]): + """A Component that is made up of named spans.""" + + def __init__(self, **kwargs): + """Initializes a simple component from the constructor's kwargs.""" + for key in kwargs.keys(): + if type(kwargs[key]) is str: + kwargs[key] = CBlock(value=kwargs[key]) + self._kwargs_type_check(kwargs) + self._kwargs = kwargs + + def parts(self): + """Returns the values of the kwargs.""" + return list(self._kwargs.values()) + + def _kwargs_type_check(self, kwargs): + for key in kwargs.keys(): + value = kwargs[key] + assert issubclass(type(value), Component) or issubclass( + type(value), CBlock + ), f"Expected span but found {type(value)} of value: {value}" + assert type(key) is str + return True + + @staticmethod + def make_simple_string(kwargs): + """Uses <|key|>value to represent a simple component.""" + return "\n".join( + [f"<|{key}|>{value}" for (key, value) in kwargs.items()] + ) + + @staticmethod + def make_json_string(kwargs): + """Uses json.""" + str_args = dict() + for key in kwargs.keys(): + match kwargs[key]: + case ModelOutputThunk() | CBlock(): + str_args[key] = kwargs[key].value + case Component(): + str_args[key] = kwargs[key].format_for_llm() + import json + + return json.dumps(str_args) + + def format_for_llm(self): + """Uses a string rep.""" + return SimpleComponent.make_json_string(self._kwargs) + + def _parse(self, computed: ModelOutputThunk) -> str: + """Parse the model output.
Returns string value for now.""" + return computed.value if computed.value is not None else "" diff --git a/mellea/stdlib/test_based_eval.py b/mellea/stdlib/components/test_based_eval.py similarity index 97% rename from mellea/stdlib/test_based_eval.py rename to mellea/stdlib/components/test_based_eval.py index 56df2344..f7f4bf6e 100644 --- a/mellea/stdlib/test_based_eval.py +++ b/mellea/stdlib/components/test_based_eval.py @@ -6,12 +6,7 @@ from pydantic import BaseModel, Field, field_validator -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) +from ...core import CBlock, Component, ModelOutputThunk, TemplateRepresentation class Message(BaseModel): diff --git a/mellea/stdlib/context.py b/mellea/stdlib/context.py new file mode 100644 index 00000000..7f4b0da7 --- /dev/null +++ b/mellea/stdlib/context.py @@ -0,0 +1,37 @@ +"""Basic Contexts.""" + +from __future__ import annotations + +# Leave unused `ContextTurn` import for import ergonomics. +from ..core import CBlock, Component, Context, ContextTurn + + +class ChatContext(Context): + """Initializes a chat context with unbounded window_size and is_chat=True by default.""" + + def __init__(self, *, window_size: int | None = None): + """Constructs a new chat context.""" + super().__init__() + self._window_size = window_size + + def add(self, c: Component | CBlock) -> ChatContext: + """Add a new component/cblock to the context. Returns the new context.""" + new = ChatContext.from_previous(self, c) + new._window_size = self._window_size + return new + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Returns the context in a linearized form. Uses the window_size set during initialization.""" + return self.as_list(self._window_size) + + +class SimpleContext(Context): + """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" + + def add(self, c: Component | CBlock) -> SimpleContext: + """Add a new component/cblock to the context. 
Returns the new context.""" + return SimpleContext.from_previous(self, c) + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Returns an empty list.""" + return [] diff --git a/mellea/stdlib/docs/__init__.py b/mellea/stdlib/docs/__init__.py deleted file mode 100644 index 0af4642e..00000000 --- a/mellea/stdlib/docs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Classes and functions for working with document-like objects.""" diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 950e103c..36fe1ca0 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -4,34 +4,31 @@ import asyncio from collections.abc import Coroutine -from typing import Any, Literal, TypeVar, overload +from typing import Any, Literal, overload from PIL import Image as PILImage -from mellea.backends import Backend, BaseModelSubclass -from mellea.backends.formatter import FormatterBackend -from mellea.helpers.event_loop_helper import _run_async_in_thread -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( +from ..backends import FormatterBackend +from ..core import ( + Backend, + BaseModelSubclass, CBlock, Component, Context, + FancyLogger, GenerateLog, ImageBlock, ModelOutputThunk, + Requirement, S, - SimpleContext, -) -from mellea.stdlib.chat import Message, ToolMessage -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.mify import mify -from mellea.stdlib.mobject import MObjectProtocol -from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.sampling import ( - RejectionSamplingStrategy, SamplingResult, SamplingStrategy, + ValidationResult, ) +from ..helpers import _run_async_in_thread +from .components import Instruction, Message, MObjectProtocol, ToolMessage, mify +from .context import SimpleContext +from .sampling import RejectionSamplingStrategy @overload diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py deleted file mode 100644 index 8ad261e8..00000000 --- a/mellea/stdlib/requirement.py +++ /dev/null @@ -1,397 +0,0 @@ -"""Requirements are a special type of Component used as input to the "validate" step in Instruct/Validate/Repair design patterns.""" - -import inspect -import json -import re -from collections.abc import Callable -from copy import copy -from typing import Any, overload - -from mellea.backends import Backend, BaseModelSubclass -from mellea.backends.adapters.adapter import AdapterType -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( - CBlock, - Component, - Context, - ModelOutputThunk, - TemplateRepresentation, -) -from mellea.stdlib.intrinsics.intrinsic import Intrinsic - - -def default_output_to_bool(x: CBlock | str) -> bool: - """Checks if a given output should be marked converted to `True`. - - Checks if the output is exactly equal to "yes" or "y" (case-insensitive). If not, it will also - check if any of the words in the output are "yes" (case-insensitive). - """ - output = str(x) - - if output.upper() == "YES" or output.upper() == "Y": - return True - - word_splits = re.split(r"\W+", output) - if "YES" in [word.upper() for word in word_splits]: - return True - - return False - - -class ValidationResult: - """ValidationResults store the output of a Requirement's validation. 
They can be used to return additional info from validation functions, which is useful for sampling/repairing.""" - - def __init__( - self, - result: bool, - *, - reason: str | None = None, - score: float | None = None, - thunk: ModelOutputThunk | None = None, - context: Context | None = None, - ): - """The result of a requirement's validation. - - A ValidationResult's result field always contains a definitive pass/fail. The other fields can be used to communicate additional information about that result. - - Args: - result: a boolean that is true if the requirement passed - reason: a reason for the result - score: if your validator gives you a score back, you can add this as metadata - thunk: if your validator utilizes a backend to generate a response, the ModelOutputThunk returned from that request - context: if your validator utilizes a backend to generate a response, the context associated with that response - """ - self._result = result - self._reason = reason - self._score = score - self._thunk = thunk - self._context = context - - @property - def reason(self) -> str | None: - """Reason for the validation result.""" - return self._reason - - @property - def score(self) -> float | None: - """An optional score for the validation result.""" - return self._score - - @property - def thunk(self) -> ModelOutputThunk | None: - """The ModelOutputThunk associated with the validation func if an llm was used to generate the final result.""" - return self._thunk - - @property - def context(self) -> Context | None: - """The context associated with validation if a backend was used to generate the final result.""" - return self._context - - def as_bool(self) -> bool: - """Return a boolean value based on the result.""" - return self._result - - def __bool__(self) -> bool: - """Return a boolean value based on the result.""" - return self.as_bool() - - -class Requirement(Component[str]): - """Requirements are a special type of Component used as input to the Validate step in Instruct/Validate/Repair patterns.""" - - def __init__( - self, - description: str | None = None, - validation_fn: Callable[[Context], ValidationResult] | None = None, - *, - output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool, - check_only: bool = False, - ): - """A Requirement, interpreted over a Context. - - By default, requirements are validated by the model using LLM-as-a-Judge (or a `constraint` LoRA when available). However, you can also provide a `validate` function with arbitrary behavior. - - Args: - description: A natural-language description of the requirement. This will sometimes be included in `Instruction` prompts; if you do not want the requirement to be included in the prompt to avoid [Purple Elephant Effects](https://${PROJECT_URL}/llm-requirement-engineering-and-purple-elephants/) use check_only=True. - validation_fn: If provided, this function will be executed instead of using LLM-as-a-Judge. The `bool()` for the function's output defines whether the requirement passes. - output_to_bool: An `output_to_bool` may be provided so that the library can translate the LLM-as-a-judge or ALora output into a boolean value. If none is provided, we will look for 'yes' (case-insensitive) in the LLMaJ output. - check_only: If set, then `Instruction` will not include this requirement in its prompt. - """ - self.description = description - self.output_to_bool = output_to_bool - self.validation_fn = validation_fn - self.check_only = check_only - - # Used for validation. Do not manually populate. 
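As context for the `Requirement`/`ValidationResult` API being relocated in this hunk, a minimal sketch of a Python-validated requirement; import paths are assumed from this PR's new layout and the helper name `_max_200_chars` is illustrative only:

```python
from mellea.core import Context, Requirement, ValidationResult

def _max_200_chars(ctx: Context) -> ValidationResult:
    # Inspect only the most recent model output, per the validate() contract described above.
    out = ctx.last_output()
    text = out.value if out is not None and out.value is not None else ""
    ok = len(text) <= 200
    return ValidationResult(ok, reason=None if ok else "Response exceeded 200 characters.")

length_req = Requirement(
    "Keep the answer under 200 characters.",
    validation_fn=_max_200_chars,
    check_only=True,  # keep the constraint out of the Instruction prompt
)
```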
- self._output: str | None = None - - async def validate( - self, - backend: Backend, - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - ) -> ValidationResult: - """Chooses the appropriate validation strategy and applies that strategy.""" - if self.validation_fn is not None: - # Python validation strategy - return self.validation_fn(ctx) - else: - # LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch. - assert self.output_to_bool is not None - last_output = ctx.last_output() - assert isinstance(last_output, ModelOutputThunk), ( - " Context has no appropriate last output" - ) - - # Create a copy of the requirement that holds the output - # and its template gets populated with the output correctly. - req_copy = copy(self) - req_copy._output = last_output.value - llm_as_a_judge_result, val_ctx = await backend.generate_from_context( - req_copy, ctx, format=format, model_options=model_options - ) - await llm_as_a_judge_result.avalue() - - return ValidationResult( - result=self.output_to_bool(llm_as_a_judge_result), - reason=llm_as_a_judge_result.value, - thunk=llm_as_a_judge_result, - context=val_ctx, - ) - - def parts(self): - """Returns all of the constituent parts of a Requirement.""" - return [] - - def format_for_llm(self) -> TemplateRepresentation | str: - """Some object protocol magic happens here with management of the output.""" - assert self._output is not None, ( - "Object protocol error: should never try to templatize a Requirement except inside of a validate call for that same requirement." - ) - return TemplateRepresentation( - obj=self, - args={"description": self.description, "output": self._output}, - tools=None, - template_order=["*", "Requirement"], - ) - - def _parse(self, computed: ModelOutputThunk) -> str: - """Parse the model output. Returns string value for now.""" - return computed.value if computed.value is not None else "" - - -class LLMaJRequirement(Requirement): - """A requirement that always uses LLM-as-a-Judge. Any available constraint ALoRA will be ignored.""" - - use_aloras: bool = False - - -def requirement_check_to_bool(x: CBlock | str) -> bool: - """Checks if a given output should be marked converted to `True`. - - By default, the requirement check alora outputs: `{"requirement_likelihood": 0.0}`. - True if >.5 - """ - output = str(x) - req_dict: dict[str, Any] = json.loads(output) - - likelihood = req_dict.get("requirement_likelihood", None) - if likelihood is None: - FancyLogger.get_logger().warning( - f"could not get value from alora requirement output; looking for `requirement_likelihood` in {req_dict}" - ) - return False - - if likelihood > 0.5: - return True - - return False - - -class ALoraRequirement(Requirement, Intrinsic): - """A requirement that always uses an (possibly specified) ALora. If an exception is thrown during the ALora execution path, `mellea` will fall back to LLMaJ. But that is the only case where LLMaJ will be used.""" - - def __init__(self, description: str, intrinsic_name: str | None = None): - """A requirement that is validated by an ALora. - - Args: - description: See `Requirement.__init__` - intrinsic_name: the name of the intrinsic; must match the adapter - """ - # TODO: We may want to actually do the validation_fn here so that we can set the score. 
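A small worked example of the thresholding performed by `requirement_check_to_bool` on the adapter's JSON output; the payload string below is illustrative:

```python
import json

raw = '{"requirement_likelihood": 0.83}'
likelihood = json.loads(raw).get("requirement_likelihood")
# Missing or <= 0.5 likelihoods are treated as a failed requirement check.
passed = likelihood is not None and likelihood > 0.5  # -> True
```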
- super().__init__( - description, validation_fn=None, output_to_bool=requirement_check_to_bool - ) - self.use_aloras: bool = True - - if intrinsic_name is None: - intrinsic_name = "requirement_check" - - # Initialize the other side of the inheritance tree - Intrinsic.__init__( - self, - intrinsic_name=intrinsic_name, - intrinsic_kwargs={"requirement": f"{self.description}"}, - ) - - -class ScorerRequirement(Requirement): - """A requirement that always returns a non-None score. The scorer must also define a preference ordering to indicate whether the goal is to maximize or minimize the score.""" - - def __init__( - self, - description: str | None = None, - validation_fn: Callable[[Context], ValidationResult] | None = None, - preference_ordering: str = "max", - *, - output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool, - check_only: bool = False, - ): - """A requirement that is validated by an ALora. - - Args: - description: See `Requirement.__init__` - validation_fn: If provided, this function will be executed instead of using LLM-as-a-Judge. This function must return a valid score - preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min". Defaults to None - output_to_bool: See `Requirement.__init__` - check_only: See `Requirement.__init__` - """ - super().__init__( - description, - validation_fn=validation_fn, - output_to_bool=output_to_bool, - check_only=check_only, - ) - - if preference_ordering.lower() not in ["max", "min"]: - raise NotImplementedError - self.preference_ordering: str = preference_ordering.lower() - - async def validate( - self, - backend: Backend, - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - ) -> ValidationResult: - """Chooses the appropriate validation strategy and applies that strategy. Asserts that the returned ValidationResult has a valid score.""" - if self.validation_fn is not None: - # Python validation strategy - validation_result = self.validation_fn(ctx) - assert validation_result._score is not None, ( - "ScorerRequirement must have a score that is not None" - ) - return validation_result - else: - # LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch. - # For ScorerRequirement, provide score of 1 for result=True, 0 for result=False - assert self.output_to_bool is not None - last_output = ctx.last_output() - assert isinstance(last_output, ModelOutputThunk), ( - " Context has no appropriate last output" - ) - - # Create a copy of the requirement that holds the output - # and its template gets populated with the output correctly. - req_copy = copy(self) - req_copy._output = last_output.value - llm_as_a_judge_result, val_ctx = await backend.generate_from_context( - req_copy, ctx, format=format, model_options=model_options - ) - await llm_as_a_judge_result.avalue() - result = self.output_to_bool(llm_as_a_judge_result) - - return ValidationResult( - result=result, - reason=llm_as_a_judge_result.value, - score=1 if result else 0, - thunk=llm_as_a_judge_result, - context=val_ctx, - ) - - -def reqify(r: str | Requirement) -> Requirement: - """Maps strings to Requirements. - - This is a utility method for functions that allow you to pass in Requirements as either explicit Requirement objects or strings that you intend to be interpreted as requirements. 
- """ - if type(r) is str: - return Requirement(r) - elif isinstance(r, Requirement): - return r - else: - raise Exception(f"reqify takes a str or requirement, not {r}") - - -def req(*args, **kwargs) -> Requirement: - """Shorthand for Requirement.__init__.""" - return Requirement(*args, **kwargs) - - -def check(*args, **kwargs) -> Requirement: - """Shorthand for Requirement.__init__(..., check_only=True).""" - return Requirement(*args, **kwargs, check_only=True) - - -@overload -def simple_validate( - fn: Callable[[str], tuple[bool, str]], -) -> Callable[[Context], ValidationResult]: ... - - -@overload -def simple_validate( - fn: Callable[[str], bool], *, reason: str | None = None -) -> Callable[[Context], ValidationResult]: ... - - -def simple_validate( - fn: Callable[[str], Any], *, reason: str | None = None -) -> Callable[[Context], ValidationResult]: - """Syntactic sugar for writing validation functions that only operate over the last output from the model (interpreted as a string). - - This is useful when your validation logic only depends upon the most recent model output. For example: - - `Requirement("Answer 'yes' or 'no'", simple_validate(lambda x: x == 'yes' or x == 'no')` - - Validation functions operate over `Context`. Often you do not care about the entire context, and just want to consider the most recent output from the model. - - Important notes: - - this operates over the more recent _model output_, not the most recent message. - - Model outputs are sometimes parsed into more complex types (eg by a `Formatter.parse` call or an OutputProcessor). This validation logic will interpret the most recent output as a string, regardless of whether it has a more complex parsed representation. - - Args: - fn: the simple validation function that takes a string and returns either a bool or (bool, str) - reason: only used if the provided function returns a bool; if the validation function fails, a static reason for that failure to give to the llm when repairing - """ - - def validate(ctx: Context) -> ValidationResult: - o = ctx.last_output() - if o is None or o.value is None: - FancyLogger.get_logger().warn( - "Last output of context was None. That might be a problem. We return validation as False to be able to continue..." - ) - return ValidationResult( - False - ) # Don't pass in the static reason since the function didn't run. - - result = fn(o.value) - - # Only confirm that the result conforms to the fn type requirements here. Functions can - # declare return types and then deviate from them. - - # Oneliner that checks the tuple actually contains (bool, str) - if isinstance(result, tuple) and list(map(type, result)) == [bool, str]: - return ValidationResult(result[0], reason=result[1]) - - elif type(result) is bool: - return ValidationResult(result, reason=reason) - - raise ValueError( - f"function {fn.__name__} passed to simple_validate didn't return either bool or [bool, str]; returned {type(result)} instead" - ) - - return validate diff --git a/mellea/stdlib/requirements/__init__.py b/mellea/stdlib/requirements/__init__.py new file mode 100644 index 00000000..9f936d78 --- /dev/null +++ b/mellea/stdlib/requirements/__init__.py @@ -0,0 +1,16 @@ +"""Module for working with Requirements.""" + +# Import from core for ergonomics. 
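For orientation, a sketch of how the shorthand constructors behave once they are re-exported from the new `mellea.stdlib.requirements` package (import path taken from the `__init__.py` being added here):

```python
from mellea.stdlib.requirements import check, req, reqify

r1 = req("Respond in formal English.")        # plain Requirement
r2 = check("Do not mention competitors.")     # same, but check_only=True (kept out of prompts)
r3 = reqify("Answer in a single sentence.")   # str -> Requirement; Requirement instances pass through
```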
+from ...core import Requirement, ValidationResult, default_output_to_bool +from .md import as_markdown_list, is_markdown_list, is_markdown_table +from .python_reqs import PythonExecutionReq +from .requirement import ( + ALoraRequirement, + LLMaJRequirement, + check, + req, + reqify, + requirement_check_to_bool, + simple_validate, +) +from .tool_reqs import tool_arg_validator, uses_tool diff --git a/mellea/stdlib/reqlib/md.py b/mellea/stdlib/requirements/md.py similarity index 95% rename from mellea/stdlib/reqlib/md.py rename to mellea/stdlib/requirements/md.py index 9a1836ed..8d3ed00d 100644 --- a/mellea/stdlib/reqlib/md.py +++ b/mellea/stdlib/requirements/md.py @@ -2,8 +2,7 @@ import mistletoe -from mellea.stdlib.base import Context -from mellea.stdlib.requirement import Requirement +from ...core import Context, Requirement # region lists diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/requirements/python_reqs.py similarity index 95% rename from mellea/stdlib/reqlib/python.py rename to mellea/stdlib/requirements/python_reqs.py index 1c83a330..c5dd916f 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -1,17 +1,7 @@ """Requirements for Python code generation validation.""" -import ast -import subprocess -import sys -import tempfile -from abc import ABC, abstractmethod -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import Context -from mellea.stdlib.requirement import Requirement, ValidationResult +from collections.abc import Callable + from mellea.stdlib.tools.interpreter import ( ExecutionEnvironment, LLMSandboxEnvironment, @@ -19,6 +9,8 @@ UnsafeEnvironment, ) +from ...core import Context, FancyLogger, Requirement, ValidationResult + logger = FancyLogger.get_logger() @@ -196,5 +188,9 @@ def __init__( check_only=True, ) + # Add type hint to validation_fn here. It's always set for this requirement. + self.validation_fn: Callable[[Context], ValidationResult] + assert self.validation_fn is not None + # endregion diff --git a/mellea/stdlib/requirements/requirement.py b/mellea/stdlib/requirements/requirement.py new file mode 100644 index 00000000..17799eb5 --- /dev/null +++ b/mellea/stdlib/requirements/requirement.py @@ -0,0 +1,147 @@ +"""Requirements are a special type of Component used as input to the "validate" step in Instruct/Validate/Repair design patterns.""" + +import json +from collections.abc import Callable +from typing import Any, overload + +from ...core import CBlock, Context, FancyLogger, Requirement, ValidationResult +from ..components.intrinsic import Intrinsic + + +class LLMaJRequirement(Requirement): + """A requirement that always uses LLM-as-a-Judge. Any available constraint ALoRA will be ignored.""" + + use_aloras: bool = False + + +def requirement_check_to_bool(x: CBlock | str) -> bool: + """Checks if a given output should be marked converted to `True`. + + By default, the requirement check alora outputs: `{"requirement_likelihood": 0.0}`. 
+ True if >.5 + """ + output = str(x) + req_dict: dict[str, Any] = json.loads(output) + + likelihood = req_dict.get("requirement_likelihood", None) + if likelihood is None: + FancyLogger.get_logger().warning( + f"could not get value from alora requirement output; looking for `requirement_likelihood` in {req_dict}" + ) + return False + + if likelihood > 0.5: + return True + + return False + + +class ALoraRequirement(Requirement, Intrinsic): + """A requirement that always uses an (possibly specified) ALora. If an exception is thrown during the ALora execution path, `mellea` will fall back to LLMaJ. But that is the only case where LLMaJ will be used.""" + + def __init__(self, description: str, intrinsic_name: str | None = None): + """A requirement that is validated by an ALora. + + Args: + description: See `Requirement.__init__` + intrinsic_name: the name of the intrinsic; must match the adapter + """ + # TODO: We may want to actually do the validation_fn here so that we can set the score. + super().__init__( + description, validation_fn=None, output_to_bool=requirement_check_to_bool + ) + self.use_aloras: bool = True + + if intrinsic_name is None: + intrinsic_name = "requirement_check" + + # Initialize the other side of the inheritance tree + Intrinsic.__init__( + self, + intrinsic_name=intrinsic_name, + intrinsic_kwargs={"requirement": f"{self.description}"}, + ) + + +def reqify(r: str | Requirement) -> Requirement: + """Maps strings to Requirements. + + This is a utility method for functions that allow you to pass in Requirements as either explicit Requirement objects or strings that you intend to be interpreted as requirements. + """ + if type(r) is str: + return Requirement(r) + elif isinstance(r, Requirement): + return r + else: + raise Exception(f"reqify takes a str or requirement, not {r}") + + +def req(*args, **kwargs) -> Requirement: + """Shorthand for Requirement.__init__.""" + return Requirement(*args, **kwargs) + + +def check(*args, **kwargs) -> Requirement: + """Shorthand for Requirement.__init__(..., check_only=True).""" + return Requirement(*args, **kwargs, check_only=True) + + +@overload +def simple_validate( + fn: Callable[[str], tuple[bool, str]], +) -> Callable[[Context], ValidationResult]: ... + + +@overload +def simple_validate( + fn: Callable[[str], bool], *, reason: str | None = None +) -> Callable[[Context], ValidationResult]: ... + + +def simple_validate( + fn: Callable[[str], Any], *, reason: str | None = None +) -> Callable[[Context], ValidationResult]: + """Syntactic sugar for writing validation functions that only operate over the last output from the model (interpreted as a string). + + This is useful when your validation logic only depends upon the most recent model output. For example: + + `Requirement("Answer 'yes' or 'no'", simple_validate(lambda x: x == 'yes' or x == 'no')` + + Validation functions operate over `Context`. Often you do not care about the entire context, and just want to consider the most recent output from the model. + + Important notes: + - this operates over the more recent _model output_, not the most recent message. + - Model outputs are sometimes parsed into more complex types (eg by a `Formatter.parse` call or an OutputProcessor). This validation logic will interpret the most recent output as a string, regardless of whether it has a more complex parsed representation. 
+ + Args: + fn: the simple validation function that takes a string and returns either a bool or (bool, str) + reason: only used if the provided function returns a bool; if the validation function fails, a static reason for that failure to give to the llm when repairing + """ + + def validate(ctx: Context) -> ValidationResult: + o = ctx.last_output() + if o is None or o.value is None: + FancyLogger.get_logger().warn( + "Last output of context was None. That might be a problem. We return validation as False to be able to continue..." + ) + return ValidationResult( + False + ) # Don't pass in the static reason since the function didn't run. + + result = fn(o.value) + + # Only confirm that the result conforms to the fn type requirements here. Functions can + # declare return types and then deviate from them. + + # Oneliner that checks the tuple actually contains (bool, str) + if isinstance(result, tuple) and list(map(type, result)) == [bool, str]: + return ValidationResult(result[0], reason=result[1]) + + elif type(result) is bool: + return ValidationResult(result, reason=reason) + + raise ValueError( + f"function {fn.__name__} passed to simple_validate didn't return either bool or [bool, str]; returned {type(result)} instead" + ) + + return validate diff --git a/mellea/stdlib/safety/__init__.py b/mellea/stdlib/requirements/safety/__init__.py similarity index 100% rename from mellea/stdlib/safety/__init__.py rename to mellea/stdlib/requirements/safety/__init__.py diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/requirements/safety/guardian.py similarity index 97% rename from mellea/stdlib/safety/guardian.py rename to mellea/stdlib/requirements/safety/guardian.py index 33834163..41f49441 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/requirements/safety/guardian.py @@ -3,12 +3,17 @@ from enum import Enum from typing import Literal -from mellea.backends import Backend, BaseModelSubclass -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk -from mellea.stdlib.chat import Message -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ValidationResult +from ....core import ( + Backend, + BaseModelSubclass, + CBlock, + Context, + FancyLogger, + Requirement, + ValidationResult, +) +from ...components import Message +from ...context import ChatContext class GuardianRisk(Enum): @@ -218,7 +223,7 @@ async def validate( # Try to reuse chat history directly when available. 
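A usage sketch for the `simple_validate` helper reproduced above, wrapping a plain string predicate into a `Context`-based validator; import paths are assumed from this PR's layout:

```python
from mellea.core import Requirement
from mellea.stdlib.requirements import simple_validate

yes_no = Requirement(
    "Answer 'yes' or 'no'.",
    simple_validate(
        lambda s: s.strip().lower() in ("yes", "no"),
        reason="Expected a bare yes/no answer.",  # static reason used only when the bool-returning fn fails
    ),
)
```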
messages = None try: - from mellea.stdlib.chat import as_chat_history + from ...components.chat import as_chat_history messages = as_chat_history(ctx) except Exception: diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/requirements/tool_reqs.py similarity index 96% rename from mellea/stdlib/reqlib/tools.py rename to mellea/stdlib/requirements/tool_reqs.py index 6b64f18a..e38c42cf 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/requirements/tool_reqs.py @@ -1,10 +1,8 @@ """Requirements for tool-use workflows.""" from collections.abc import Callable -from typing import Optional -from mellea.stdlib.base import Context -from mellea.stdlib.requirement import Requirement, ValidationResult +from ...core import Context, Requirement, ValidationResult def _name2str(tool_name: str | Callable) -> str: diff --git a/mellea/stdlib/rewards/__init__.py b/mellea/stdlib/rewards/__init__.py deleted file mode 100644 index 3c1aede3..00000000 --- a/mellea/stdlib/rewards/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Components used with reward models.""" diff --git a/mellea/stdlib/rewards/prm_scorer.py b/mellea/stdlib/rewards/prm_scorer.py deleted file mode 100644 index 5653cd99..00000000 --- a/mellea/stdlib/rewards/prm_scorer.py +++ /dev/null @@ -1,57 +0,0 @@ -"""PRM Requirements.""" - -from mellea.backends.huggingface import HFProcessRewardModel -from mellea.stdlib.base import CBlock, Context -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import ScorerRequirement, ValidationResult - - -class PRMScorer(ScorerRequirement): - """A process reward model scorer based on local huggingface backend.""" - - def __init__( - self, *, prm_model: HFProcessRewardModel, preference_ordering: str = "max" - ): - """Instantiate a process reward model scorer based on local huggingface backend. - - Args: - prm_model: The PRM model - preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min". - """ - super().__init__( - check_only=True, - validation_fn=lambda c: self._prm_validate(c), - preference_ordering=preference_ordering, - ) - - self.model: HFProcessRewardModel = prm_model - - def _prm_validate(self, ctx: Context): - """Returns PRM score of last turn of context.""" - last_turn = ctx.last_turn() - assert last_turn is not None - - # This requirement can handle only complete turns with both - # a user message and an assistant message - - assert last_turn.model_input is not None and last_turn.output is not None - assert last_turn.output.value is not None - - user_msg = last_turn.model_input - - # Handle the variety of possible user input. - if isinstance(user_msg, CBlock) and user_msg.value is not None: - user_query = user_msg.value - elif isinstance(user_msg, Message) and user_msg.content != "": - user_query = user_msg.content - else: - user_query = str(user_msg) - - assistant_content = last_turn.output.value - - rewards, rewards_per_step = self.model.score(user_query, assistant_content) - - # return single reward item for the response - assert len(rewards) == 1 - - return ValidationResult(result=True, reason=None, score=rewards[0]) diff --git a/mellea/stdlib/sampling/__init__.py b/mellea/stdlib/sampling/__init__.py index 81bc507c..2ad920b2 100644 --- a/mellea/stdlib/sampling/__init__.py +++ b/mellea/stdlib/sampling/__init__.py @@ -1,9 +1,10 @@ """sampling methods go here.""" +# Import from core for ergonomics. 
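With the re-exports added to the sampling `__init__.py` in this hunk, strategies and their result types resolve from a single module. A sketch; the `loop_budget` keyword is assumed from how `BaseSamplingStrategy` is used elsewhere in this diff, not verified here:

```python
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult

# SamplingResult is what strategy.sample(...) ultimately returns.
strategy = RejectionSamplingStrategy(loop_budget=3)  # assumed kwarg: retry up to 3 generate/validate rounds
```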
+from ...core import SamplingResult, SamplingStrategy from .base import ( BaseSamplingStrategy, MultiTurnStrategy, RejectionSamplingStrategy, RepairTemplateStrategy, ) -from .types import S, SamplingResult, SamplingStrategy diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 6c00a008..d89cb074 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -5,15 +5,22 @@ import tqdm -import mellea.stdlib.functional as mfuncs -from mellea.backends import Backend, BaseModelSubclass -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk -from mellea.stdlib.chat import Message -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ValidationResult - -from .types import S, SamplingResult, SamplingStrategy +from ...core import ( + Backend, + BaseModelSubclass, + Component, + Context, + FancyLogger, + ModelOutputThunk, + Requirement, + S, + SamplingResult, + SamplingStrategy, + ValidationResult, +) +from ...stdlib import functional as mfuncs +from ..components import Instruction, Message +from ..context import ChatContext class BaseSamplingStrategy(SamplingStrategy): diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py deleted file mode 100644 index b0b827ef..00000000 --- a/mellea/stdlib/sampling/best_of_n.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Best of N Sampling Strategy.""" - -from copy import deepcopy - -import tqdm - -import mellea.stdlib.functional as mfuncs -from mellea.backends import Backend, BaseModelSubclass -from mellea.helpers.async_helpers import wait_for_all_mots -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult -from mellea.stdlib.sampling import BaseSamplingStrategy, SamplingResult -from mellea.stdlib.sampling.types import S - - -class BestofNSamplingStrategy(BaseSamplingStrategy): - """Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer.""" - - async def sample( - self, - action: Component[S], - context: Context, - backend: Backend, - requirements: list[Requirement] | None, - *, - validation_ctx: Context | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - show_progress: bool = True, - ) -> SamplingResult[S]: - """This method performs a sampling operation based on the given instruction. - - Args: - action : The action object to be sampled. - context: The context to be passed to the sampling strategy. - backend: The backend used for generating samples. - requirements: List of requirements to test against (merged with global requirements). - validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. - format: output format for structured outputs. - model_options: model options to pass to the backend during generation / validation. - tool_calls: True if tool calls should be used during this sampling strategy. - show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. - - Returns: - SamplingResult: A result object indicating the success or failure of the sampling process. 
- - Raises: - AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. - """ - validation_ctx = validation_ctx if validation_ctx is not None else context - assert validation_ctx is not None, "Validation context must be provided." - - flog = FancyLogger.get_logger() - - sampled_results: list[ModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] - - successful_sampled_results: list[ModelOutputThunk] = [] - successful_sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - successful_sampled_actions: list[Component] = [] - successful_sample_contexts: list[Context] = [] - - # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress - # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO - - reqs = [] - if self.requirements is not None: - reqs += self.requirements - elif requirements is not None: - reqs += requirements - - reqs = list(set(reqs)) - - # check that there is exactly one ScorerRequirement - scorer_requirements = 0 - for req in reqs: - # strict typecheck for scorer requirement - if isinstance(req, ScorerRequirement): - scorer_requirements += 1 - - assert scorer_requirements == 1, ( - "BestOfNSamplingStrategy requires exactly one ScorerRequirement" - ) - - loop_count = 0 - generate_loop_budget_iterator = ( - tqdm.tqdm(range(self.loop_budget)) # type: ignore - if show_progress - else range(self.loop_budget) # type: ignore - ) - validate_loop_budget_iterator = ( - tqdm.tqdm(range(self.loop_budget)) # type: ignore - if show_progress - else range(self.loop_budget) # type: ignore - ) - - next_action = deepcopy(action) - next_context = context - flog.info("BestofNSampling Generating Loop:") - for _ in generate_loop_budget_iterator: # type: ignore - loop_count += 1 - if not show_progress: - flog.info(f"Running loop {loop_count} of {self.loop_budget}") - - # run a generation pass - result, result_ctx = await backend.generate_from_context( - next_action, - ctx=next_context, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - # Sampling strategies may use different components from the original - # action. This might cause discrepancies in the expected parsed_repr - # type / value. Explicitly overwrite that here. - result.parsed_repr = action.parse(result) - - sampled_results.append(result) - sampled_actions.append(next_action) - sample_contexts.append(result_ctx) - - await wait_for_all_mots(sampled_results) - - flog.info("BestofNSampling Validation Loop:") - for i in validate_loop_budget_iterator: - result_ctx = sample_contexts[i] - result = sampled_results[i] - next_action = sampled_actions[i] - - val_scores_co = mfuncs.avalidate( - reqs=reqs, - context=result_ctx, - backend=backend, - output=result, - format=None, - model_options=model_options, - input=next_action._description, # type: ignore - # tool_calls=tool_calls # Don't support using tool calls in validation strategies. 
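A worked illustration of the success check used in the validation loop below: every `(Requirement, ValidationResult)` pair must be truthy before a sample is kept (import path assumed from this PR):

```python
from mellea.core import Requirement, ValidationResult

constraint_scores = [
    (Requirement("Use a formal tone."), ValidationResult(True)),
    (Requirement("Stay under 200 words."), ValidationResult(False, reason="too long")),
]
all_passed = all(bool(score) for _, score in constraint_scores)          # False -> repair and resample
failed = [r.description for r, score in constraint_scores if not score]  # ["Stay under 200 words."]
```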
- ) - val_scores = await val_scores_co - - # match up reqs with scores - constraint_scores = list(zip(reqs, val_scores)) - - # collect all data - sampled_scores.append(constraint_scores) - - # check if requirements pass else repair and re-sample - # if all vals are true, save it and continue to get next sample - if all(bool(s[1]) for s in constraint_scores): - flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - result._generate_log.is_final_result = True - - successful_sampled_results.append(result) - successful_sampled_scores.append(constraint_scores) - successful_sampled_actions.append(next_action) - successful_sample_contexts.append(result_ctx) - - else: - # log partial success and continue - count_valid = len([s for s in constraint_scores if bool(s[1])]) - flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") - - # If we did not pass all constraints, update the instruction and try again. - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, - ) - - # find max reward amongst results for which all requirements have passed - if len(successful_sampled_scores) > 0: - scores: list[float] = [] - scorer_preference_ordering = None - - for sample in successful_sampled_scores: - for req, val_score in sample: - if isinstance(req, ScorerRequirement): - assert val_score._score is not None - scores.append(val_score._score) - scorer_preference_ordering = req.preference_ordering - - assert len(successful_sampled_results) == len(scores) - assert scorer_preference_ordering is not None - - if scorer_preference_ordering == "max": - best_result, best_score, best_context = max( - zip(successful_sampled_results, scores, successful_sample_contexts), - key=lambda x: x[1], - ) - elif scorer_preference_ordering == "min": - best_result, best_score, best_context = min( - zip(successful_sampled_results, scores, successful_sample_contexts), - key=lambda x: x[1], - ) - else: - raise NotImplementedError - - best_index = sampled_results.index(best_result) - - return SamplingResult( - result_index=best_index, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - sample_contexts=sample_contexts, - ) - - # if all failures, call select from failure - else: - flog.info( - f"Invoking select_from_failure after {len(sampled_results)} failed attempts." - ) - - # if no valid result could be determined, find a last resort. - best_failed_index = self.select_from_failure( - sampled_actions, sampled_results, sampled_scores - ) - assert best_failed_index < len(sampled_results), ( - "The select_from_failure method did not return a valid result. It has to selected from failed_results." - ) - return SamplingResult( - result_index=best_failed_index, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - sample_contexts=sample_contexts, - ) - - @staticmethod - def select_from_failure( - sampled_actions: list[Component], - sampled_results: list[ModelOutputThunk], - sampled_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> int: - """Selects the attempt with the highest score. - - Args: - sampled_actions: List of actions that have been executed (without success). - sampled_results: List of (unsuccessful) generation results for these actions. - sampled_val: List of validation results for the results. 
- - Returns: - The index of the result that should be selected as `.value`. - """ - scores: list[float | None] = [] - - for sample in sampled_val: - for req, val_score in sample: - if isinstance(req, ScorerRequirement): - assert val_score._score is not None - scores.append(val_score._score) - - assert len(sampled_results) == len(scores) - - return scores.index(max(scores)) # type: ignore - - @staticmethod - def repair( - old_ctx: Context, - new_ctx: Context, - past_actions: list[Component], - past_results: list[ModelOutputThunk], - past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> tuple[Component, Context]: - """Adds a description of the requirements that failed to a copy of the original instruction. - - Args: - old_ctx: The context WITHOUT the last action + output. - new_ctx: The context including the last action + output. - past_actions: List of actions that have been executed (without success). - past_results: List of (unsuccessful) generation results for these actions. - past_val: List of validation results for the results. - - Returns: - The next action component and context to be used for the next generation attempt. - """ - pa = past_actions[-1] - if isinstance(pa, Instruction): - last_failed_reqs: list[Requirement] = [ - s[0] for s in past_val[-1] if not s[1] - ] - last_failed_reqs_str = "* " + "\n* ".join( - [str(r.description) for r in last_failed_reqs] - ) - return pa.copy_and_repair( - repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ), old_ctx - return past_actions[-1], old_ctx diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 90934938..f59aa85b 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -4,16 +4,22 @@ import tqdm -from mellea.backends import Backend, BaseModelSubclass -from mellea.backends.ollama import OllamaModelBackend -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib import functional as mfuncs -from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult -from mellea.stdlib.sampling.base import Component, Context -from mellea.stdlib.sampling.types import S -from mellea.stdlib.sampling_algos.budget_forcing_alg import think_budget_forcing +from ...backends.ollama import OllamaModelBackend +from ...core import ( + Backend, + BaseModelSubclass, + Component, + Context, + FancyLogger, + ModelOutputThunk, + Requirement, + S, + SamplingResult, + ValidationResult, +) +from ...stdlib import functional as mfuncs +from .base import RejectionSamplingStrategy +from .sampling_algos import think_budget_forcing class BudgetForcingSamplingStrategy(RejectionSamplingStrategy): diff --git a/mellea/stdlib/sampling/majority_voting.py b/mellea/stdlib/sampling/majority_voting.py index d68d06c0..d770d652 100644 --- a/mellea/stdlib/sampling/majority_voting.py +++ b/mellea/stdlib/sampling/majority_voting.py @@ -7,11 +7,16 @@ from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify from rouge_score.rouge_scorer import RougeScorer # codespell:ignore -from mellea.backends import Backend, BaseModelSubclass -from mellea.stdlib.requirement import Requirement -from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult -from mellea.stdlib.sampling.base import Component, Context -from mellea.stdlib.sampling.types import S +from ...core import ( + 
Backend, + BaseModelSubclass, + Component, + Context, + Requirement, + S, + SamplingResult, +) +from .base import RejectionSamplingStrategy class BaseMBRDSampling(RejectionSamplingStrategy): diff --git a/mellea/stdlib/sampling/sampling_algos/__init__.py b/mellea/stdlib/sampling/sampling_algos/__init__.py new file mode 100644 index 00000000..6da5e6a5 --- /dev/null +++ b/mellea/stdlib/sampling/sampling_algos/__init__.py @@ -0,0 +1,3 @@ +"""Module for Sampling Algorithms.""" + +from .budget_forcing_alg import think_budget_forcing diff --git a/mellea/stdlib/sampling_algos/budget_forcing_alg.py b/mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py similarity index 95% rename from mellea/stdlib/sampling_algos/budget_forcing_alg.py rename to mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py index 2f0a4c01..12ea3f5a 100644 --- a/mellea/stdlib/sampling_algos/budget_forcing_alg.py +++ b/mellea/stdlib/sampling/sampling_algos/budget_forcing_alg.py @@ -1,11 +1,17 @@ """Budget forcing implementation.""" -import re from typing import Any -from mellea.backends import BaseModelSubclass, ModelOption -from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk +from ....backends import ModelOption +from ....backends.ollama import OllamaModelBackend +from ....core import ( + BaseModelSubclass, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, +) async def think_budget_forcing( # noqa: D417 @@ -106,12 +112,12 @@ async def think_budget_forcing( # noqa: D417 if end_think_token: step = response.split(end_think_token)[0] # model fails to produce thoughts, let's exit - if len(step.strip()) <= min_char_len: + if len(step.strip()) <= min_char_len: # type: ignore responses.append(response) break # request more steps - step = f"{step} {think_more_suffix}" + step = f"{step} {think_more_suffix}" # type: ignore responses.append(step) curr_prompt += step diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index ef2b348b..90fc6c31 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -5,34 +5,31 @@ import contextvars import inspect from copy import copy -from typing import Any, Literal, TypeVar, overload +from typing import Any, Literal, overload from PIL import Image as PILImage -import mellea.stdlib.functional as mfuncs -from mellea.backends import Backend, BaseModelSubclass -from mellea.backends.model_ids import ( - IBM_GRANITE_3_3_8B, - IBM_GRANITE_4_MICRO_3B, - ModelIdentifier, -) -from mellea.backends.ollama import OllamaModelBackend -from mellea.backends.openai import OpenAIBackend -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( +from ..backends.model_ids import IBM_GRANITE_4_MICRO_3B, ModelIdentifier +from ..core import ( + Backend, + BaseModelSubclass, CBlock, Component, Context, + FancyLogger, GenerateLog, ImageBlock, ModelOutputThunk, + Requirement, S, - SimpleContext, + SamplingResult, + SamplingStrategy, + ValidationResult, ) -from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.sampling import SamplingResult, SamplingStrategy -from mellea.stdlib.sampling.base import RejectionSamplingStrategy +from ..stdlib import functional as mfuncs +from .components import Message +from .context import SimpleContext +from .sampling import RejectionSamplingStrategy # Global context variable for the context session _context_session: 
contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar( @@ -57,19 +54,23 @@ def get_session() -> MelleaSession: def backend_name_to_class(name: str) -> Any: """Resolves backend names to Backend classes.""" if name == "ollama": + from ..backends.ollama import OllamaModelBackend + return OllamaModelBackend elif name == "hf" or name == "huggingface": from mellea.backends.huggingface import LocalHFBackend return LocalHFBackend elif name == "openai": + from ..backends.openai import OpenAIBackend + return OpenAIBackend elif name == "watsonx": - from mellea.backends.watsonx import WatsonxAIBackend + from ..backends.watsonx import WatsonxAIBackend return WatsonxAIBackend elif name == "litellm": - from mellea.backends.litellm import LiteLLMBackend + from ..backends.litellm import LiteLLMBackend return LiteLLMBackend else: diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py index 24ca99aa..49096051 100644 --- a/mellea/stdlib/tools/__init__.py +++ b/mellea/stdlib/tools/__init__.py @@ -1,3 +1,3 @@ """Implementations of tools.""" -from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter +from .interpreter import code_interpreter, local_code_interpreter diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 9f144057..d7bc40e8 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -9,9 +9,7 @@ from pathlib import Path from typing import Any -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import Context -from mellea.stdlib.requirement import Requirement, ValidationResult +from ...core import FancyLogger logger = FancyLogger.get_logger() diff --git a/pyproject.toml b/pyproject.toml index eba285ad..913258b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "mistletoe>=1.4.0", "huggingface-hub>=0.33.4", "pillow", - "granite-common>=0.3.6", # Needed for Intrinsics. + "granite-common>=0.3.6", # Needed for Intrinsics (HF and OpenAI Backends). "math_verify", # Needed for Majority Voting Sampling Strategies. "rouge_score", # Needed for Majority Voting Sampling Strategies. "llm-sandbox[docker]>=0.3.23", diff --git a/test/backends/test_adapters/test_adapter.py b/test/backends/test_adapters/test_adapter.py index f5abbb84..632cf7cf 100644 --- a/test/backends/test_adapters/test_adapter.py +++ b/test/backends/test_adapters/test_adapter.py @@ -1,7 +1,7 @@ import pathlib import pytest -from mellea.backends.adapters.adapter import GraniteCommonAdapter +from mellea.backends.adapters import GraniteCommonAdapter # The backend tests handle most of the adapter testing. 
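A sketch of the lazy backend resolution introduced in `session.py` above: the heavy backend modules are now imported only when their name is actually requested.

```python
from mellea.stdlib.session import backend_name_to_class

ollama_cls = backend_name_to_class("ollama")    # triggers the deferred ..backends.ollama import
watsonx_cls = backend_name_to_class("watsonx")  # likewise for ..backends.watsonx
```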
Do a basic test here diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 328b6068..9c9785a3 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -12,27 +12,23 @@ from typing_extensions import Annotated from mellea import MelleaSession -from mellea.backends.adapters.adapter import GraniteCommonAdapter +from mellea.backends.adapters import GraniteCommonAdapter from mellea.backends.cache import SimpleLRUCache -from mellea.backends.formatter import TemplateFormatter +from mellea.formatters import TemplateFormatter from mellea.backends.huggingface import LocalHFBackend, _assert_correct_adapters -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ( +from mellea.backends import ModelOption +from mellea.core import ( CBlock, - ChatContext, Context, ModelOutputThunk, - SimpleContext, -) -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics.intrinsic import Intrinsic -from mellea.stdlib.requirement import ( - ALoraRequirement, - LLMaJRequirement, - Requirement, ValidationResult, default_output_to_bool, ) +from mellea.stdlib.context import ChatContext, SimpleContext + +from mellea.stdlib.components import Message +from mellea.stdlib.components import Intrinsic +from mellea.stdlib.requirements import ALoraRequirement, LLMaJRequirement @pytest.fixture(scope="module") @@ -142,7 +138,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend): # the correct actions / results in it. assert isinstance(val_result.context, Context) assert isinstance(val_result.thunk, ModelOutputThunk) - assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement) + assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement) # type: ignore assert val_result.context.node_data is val_result.thunk backend.default_to_constraint_checking_alora = True diff --git a/test/backends/test_huggingface_tools.py b/test/backends/test_huggingface_tools.py index 0df5f3dc..630e8583 100644 --- a/test/backends/test_huggingface_tools.py +++ b/test/backends/test_huggingface_tools.py @@ -1,21 +1,11 @@ -import pydantic import pytest -from typing_extensions import Annotated import mellea.backends.model_ids as model_ids from mellea import MelleaSession from mellea.backends.cache import SimpleLRUCache -from mellea.backends.formatter import TemplateFormatter from mellea.backends.huggingface import LocalHFBackend -from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, ChatContext -from mellea.stdlib.requirement import ( - ALoraRequirement, - LLMaJRequirement, - Requirement, - ValidationResult, - default_output_to_bool, -) +from mellea.backends import ModelOption +from mellea.stdlib.context import ChatContext @pytest.fixture(scope="module") diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 392dbde2..bb2d3316 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -5,8 +5,9 @@ from mellea import MelleaSession, generative from mellea.backends import ModelOption from mellea.backends.litellm import LiteLLMBackend -from mellea.stdlib.base import CBlock, SimpleContext -from mellea.stdlib.chat import Message +from mellea.core import CBlock +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.components import Message from mellea.stdlib.sampling import RejectionSamplingStrategy from mellea.backends import model_ids @@ -129,6 +130,7 @@ def 
test_gen_slot(session): @generative def is_happy(text: str) -> bool: """Determine if text is of happy mood.""" + ... h = is_happy(session, text="I'm enjoying life.") diff --git a/test/backends/test_litellm_watsonx.py b/test/backends/test_litellm_watsonx.py index 352cec57..6116c428 100644 --- a/test/backends/test_litellm_watsonx.py +++ b/test/backends/test_litellm_watsonx.py @@ -3,7 +3,7 @@ from mellea import MelleaSession from mellea.backends.litellm import LiteLLMBackend -from mellea.stdlib.base import CBlock +from mellea.core import CBlock @pytest.fixture(scope="function") diff --git a/test/backends/test_types.py b/test/backends/test_model_options.py similarity index 98% rename from test/backends/test_types.py rename to test/backends/test_model_options.py index f02d9eaf..5a75d7f3 100644 --- a/test/backends/test_types.py +++ b/test/backends/test_model_options.py @@ -1,5 +1,5 @@ import pytest -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption def test_model_option_remove(): diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 760b1edc..20636e8b 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -7,9 +7,10 @@ from mellea import start_session from mellea.backends.ollama import OllamaModelBackend -from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, SimpleContext -from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.backends import ModelOption +from mellea.core import CBlock, Requirement +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.requirements import simple_validate @pytest.fixture(scope="function") @@ -45,7 +46,7 @@ def test_instruct_with_requirement(session): happy_tone_req = Requirement( "The email should sound happy in tone.", - output_to_bool=lambda x: "happy" in x.value, + output_to_bool=lambda x: "happy" in x.value, # type: ignore ) sad_tone_req = Requirement("The email should sound sad in tone.") diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index 57ca3281..5f985203 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -6,14 +6,14 @@ import openai import pydantic import pytest -from typing_extensions import Annotated from mellea import MelleaSession -from mellea.backends.formatter import TemplateFormatter +from mellea.formatters import TemplateFormatter from mellea.backends.model_ids import META_LLAMA_3_2_1B from mellea.backends.openai import OpenAIBackend -from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, ChatContext, ModelOutputThunk, SimpleContext +from mellea.backends import ModelOption +from mellea.core import CBlock, ModelOutputThunk +from mellea.stdlib.context import ChatContext, SimpleContext @pytest.fixture(scope="module") diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index 30f17a86..d30c7f7f 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -6,17 +6,13 @@ from typing_extensions import Annotated from mellea import MelleaSession -from mellea.backends.adapters.adapter import GraniteCommonAdapter -from mellea.backends.formatter import TemplateFormatter +from mellea.backends.adapters import GraniteCommonAdapter +from mellea.formatters import TemplateFormatter from mellea.backends.openai import OpenAIBackend -from mellea.backends.types import 
ModelOption, _ServerType -from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk -from mellea.stdlib.requirement import ( - ALoraRequirement, - LLMaJRequirement, - Requirement, - req, -) +from mellea.backends import ModelOption +from mellea.core import CBlock, Context, ModelOutputThunk +from mellea.stdlib.context import ChatContext +from mellea.stdlib.requirements import ALoraRequirement, LLMaJRequirement # The vllm tests are disabled by default, because we need a test environment with the vLLM server running. # We use an env var VLLM_TESTS_ENABLED to enable these tests. @@ -91,7 +87,7 @@ class Email(pydantic.BaseModel): ) print("Formatted output:") email = Email.model_validate_json( - output.value + output.value # type: ignore ) # this should succeed because the output should be JSON because we passed in a format= argument... print(email) @@ -128,7 +124,7 @@ class Answer(pydantic.BaseModel): random_result = results[0] try: - answer = Answer.model_validate_json(random_result.value) + answer = Answer.model_validate_json(random_result.value) # type: ignore except pydantic.ValidationError as e: assert False, ( f"formatting directive failed for {random_result.value}: {e.json()}" @@ -225,7 +221,8 @@ def test_constraint_lora_override_does_not_override_alora(self): assert isinstance(non_alora_output.context, Context) assert isinstance(non_alora_output.thunk, ModelOutputThunk) assert isinstance( - non_alora_output.context.previous_node.node_data, ALoraRequirement + non_alora_output.context.previous_node.node_data, + ALoraRequirement, # type: ignore ) assert non_alora_output.context.node_data is non_alora_output.thunk @@ -284,7 +281,7 @@ class Email(pydantic.BaseModel): ) print("Formatted output:") email = Email.model_validate_json( - output.value + output.value # type: ignore ) # this should succeed because the output should be JSON because we passed in a format= argument... 
print(email) diff --git a/test/test_tool_calls.py b/test/backends/test_tool_calls.py similarity index 89% rename from test/test_tool_calls.py rename to test/backends/test_tool_calls.py index 031b14bd..7cdb292d 100644 --- a/test/test_tool_calls.py +++ b/test/backends/test_tool_calls.py @@ -1,20 +1,15 @@ import pytest -from mellea.backends import Backend from mellea.backends.ollama import OllamaModelBackend from mellea.backends.tools import ( add_tools_from_context_actions, add_tools_from_model_options, ) -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, - ChatContext, -) -from mellea.stdlib.docs.richdocument import Table +from mellea.backends import ModelOption +from mellea.core import ModelOutputThunk +from mellea.stdlib.context import ChatContext + +from mellea.stdlib.components.docs import Table from mellea.stdlib.session import MelleaSession diff --git a/test/backends/test_tool_helpers.py b/test/backends/test_tool_helpers.py index d17d6b16..4441e885 100644 --- a/test/backends/test_tool_helpers.py +++ b/test/backends/test_tool_helpers.py @@ -1,16 +1,10 @@ -from typing import Any import pytest from mellea.backends.tools import ( add_tools_from_context_actions, add_tools_from_model_options, ) -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) +from mellea.backends import ModelOption +from mellea.core import CBlock, Component, ModelOutputThunk, TemplateRepresentation class FakeToolComponent(Component[str]): diff --git a/test/stdlib_basics/test_vision_ollama.py b/test/backends/test_vision_ollama.py similarity index 89% rename from test/stdlib_basics/test_vision_ollama.py rename to test/backends/test_vision_ollama.py index eae4e87b..27a44763 100644 --- a/test/stdlib_basics/test_vision_ollama.py +++ b/test/backends/test_vision_ollama.py @@ -1,5 +1,4 @@ import base64 -import os from io import BytesIO import numpy as np @@ -8,9 +7,9 @@ from mellea import start_session, MelleaSession from mellea.backends import ModelOption -from mellea.stdlib.base import ImageBlock, ModelOutputThunk -from mellea.stdlib.chat import Message -from mellea.stdlib.instruction import Instruction +from mellea.core import ImageBlock, ModelOutputThunk +from mellea.stdlib.components import Message +from mellea.stdlib.components import Instruction @pytest.fixture(scope="module") @@ -76,21 +75,21 @@ def test_image_block_in_instruction( # if not on GH if not gh_run == 1: - assert "yes" in instr.value.lower() or "no" in instr.value.lower() + assert "yes" in instr.value.lower() or "no" in instr.value.lower() # type: ignore # make sure you get the last action turn = m_session.ctx.last_turn() assert turn is not None last_action = turn.model_input assert isinstance(last_action, Instruction) - assert len(last_action._images) > 0 + assert len(last_action._images) > 0 # type: ignore # first image in image list should be the same as the image block - image0 = last_action._images[0] + image0 = last_action._images[0] # type: ignore assert image0 == image_block # get prompt message - lp = turn.output._generate_log.prompt + lp = turn.output._generate_log.prompt # type: ignore assert isinstance(lp, list) assert len(lp) == 1 @@ -130,14 +129,14 @@ def test_image_block_in_chat( assert turn is not None last_action = turn.model_input assert isinstance(last_action, Message) - assert len(last_action.images) > 0 + assert len(last_action.images) > 0 # 
type: ignore # first image in image list should be the same as the image block - image0_str = last_action.images[0] + image0_str = last_action.images[0] # type: ignore assert image0_str == ImageBlock.from_pil_image(pil_image)._value # get prompt message - lp = turn.output._generate_log.prompt + lp = turn.output._generate_log.prompt # type: ignore assert isinstance(lp, list) assert len(lp) == 1 diff --git a/test/stdlib_basics/test_vision_openai.py b/test/backends/test_vision_openai.py similarity index 90% rename from test/stdlib_basics/test_vision_openai.py rename to test/backends/test_vision_openai.py index c922acd5..19653269 100644 --- a/test/stdlib_basics/test_vision_openai.py +++ b/test/backends/test_vision_openai.py @@ -8,9 +8,9 @@ from mellea import start_session, MelleaSession from mellea.backends import ModelOption -from mellea.stdlib.base import ImageBlock, ModelOutputThunk -from mellea.stdlib.chat import Message -from mellea.stdlib.instruction import Instruction +from mellea.core import ImageBlock, ModelOutputThunk +from mellea.stdlib.components import Message +from mellea.stdlib.components import Instruction @pytest.fixture(scope="module") @@ -80,21 +80,21 @@ def test_image_block_in_instruction( # if not on GH if not gh_run == 1: - assert "yes" in instr.value.lower() or "no" in instr.value.lower() + assert "yes" in instr.value.lower() or "no" in instr.value.lower() # type: ignore # make sure you get the last action turn = m_session.ctx.last_turn() assert turn is not None last_action = turn.model_input assert isinstance(last_action, Instruction) - assert len(last_action._images) > 0 + assert len(last_action._images) > 0 # type: ignore # first image in image list should be the same as the image block - image0 = last_action._images[0] + image0 = last_action._images[0] # type: ignore assert image0 == image_block # get prompt message - lp = turn.output._generate_log.prompt + lp = turn.output._generate_log.prompt # type: ignore assert isinstance(lp, list) assert len(lp) == 1 @@ -140,14 +140,14 @@ def test_image_block_in_chat( assert turn is not None last_action = turn.model_input assert isinstance(last_action, Message) - assert len(last_action.images) > 0 + assert len(last_action.images) > 0 # type: ignore # first image in image list should be the same as the image block - image0_str = last_action.images[0] + image0_str = last_action.images[0] # type: ignore assert image0_str == ImageBlock.from_pil_image(pil_image)._value # get prompt message - lp = turn.output._generate_log.prompt + lp = turn.output._generate_log.prompt # type: ignore assert isinstance(lp, list) assert len(lp) == 1 diff --git a/test/backends/test_vllm.py b/test/backends/test_vllm.py index cfcda8c2..bc161d9b 100644 --- a/test/backends/test_vllm.py +++ b/test/backends/test_vllm.py @@ -6,15 +6,10 @@ from mellea import MelleaSession from mellea.backends.vllm import LocalVLLMBackend -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption import mellea.backends.model_ids as model_ids -from mellea.stdlib.base import CBlock, ChatContext, SimpleContext -from mellea.stdlib.requirement import ( - LLMaJRequirement, - Requirement, - ValidationResult, - default_output_to_bool, -) +from mellea.core import CBlock +from mellea.stdlib.context import ChatContext, SimpleContext @pytest.fixture(scope="module") diff --git a/test/backends/test_vllm_tools.py b/test/backends/test_vllm_tools.py index 69c824b2..d195c1cb 100644 --- a/test/backends/test_vllm_tools.py +++ b/test/backends/test_vllm_tools.py @@ -1,19 
+1,11 @@ import os -import pydantic import pytest -from typing_extensions import Annotated from mellea import MelleaSession from mellea.backends.vllm import LocalVLLMBackend -from mellea.backends.types import ModelOption +from mellea.backends import ModelOption import mellea.backends.model_ids as model_ids -from mellea.stdlib.base import CBlock, ChatContext -from mellea.stdlib.requirement import ( - LLMaJRequirement, - Requirement, - ValidationResult, - default_output_to_bool, -) +from mellea.stdlib.context import ChatContext @pytest.fixture(scope="module") diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index 08615973..f8811097 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -6,10 +6,11 @@ import pytest from mellea import MelleaSession -from mellea.backends.formatter import TemplateFormatter -from mellea.backends.types import ModelOption +from mellea.formatters import TemplateFormatter +from mellea.backends import ModelOption from mellea.backends.watsonx import WatsonxAIBackend -from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext, SimpleContext +from mellea.core import CBlock, ModelOutputThunk +from mellea.stdlib.context import ChatContext, SimpleContext @pytest.fixture(scope="module") @@ -118,7 +119,7 @@ class Email(pydantic.BaseModel): ) print("Formatted output:") email = Email.model_validate_json( - output.value + output.value # type: ignore ) # this should succeed because the output should be JSON because we passed in a format= argument... print(email) diff --git a/test/stdlib_basics/test_base.py b/test/core/test_base.py similarity index 94% rename from test/stdlib_basics/test_base.py rename to test/core/test_base.py index c1184e64..30aa6f7f 100644 --- a/test/stdlib_basics/test_base.py +++ b/test/core/test_base.py @@ -1,7 +1,7 @@ from typing import Any import pytest -from mellea.stdlib.base import CBlock, Component, ModelOutputThunk -from mellea.stdlib.chat import Message +from mellea.core import CBlock, Component, ModelOutputThunk +from mellea.stdlib.components import Message def test_cblock(): diff --git a/test/stdlib_basics/test_component_typing.py b/test/core/test_component_typing.py similarity index 93% rename from test/stdlib_basics/test_component_typing.py rename to test/core/test_component_typing.py index d781ffa1..d6f10f7c 100644 --- a/test/stdlib_basics/test_component_typing.py +++ b/test/core/test_component_typing.py @@ -5,20 +5,19 @@ from mellea import start_session from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ( +from mellea.core import ( CBlock, - ChatContext, Component, ComponentParseError, Context, ModelOutputThunk, - SimpleContext, ) -from mellea.stdlib.chat import Message -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.sampling.base import BaseSamplingStrategy -from mellea.stdlib.session import MelleaSession +from mellea.stdlib.context import ChatContext, SimpleContext +from mellea.stdlib.components import Message +from mellea.stdlib.components import Instruction +from mellea.core import Requirement, ValidationResult +from mellea.stdlib.sampling import BaseSamplingStrategy +from mellea import MelleaSession import mellea.stdlib.functional as mfuncs @@ -81,8 +80,8 @@ def test_mot_init_typing(): assert hasattr(mot, "__orig_class__"), ( f"mots are generics and should have this field" ) - assert 
get_args(mot.__orig_class__)[0] == float, ( - f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" + assert get_args(mot.__orig_class__)[0] == float, ( # type: ignore + f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type" # type: ignore ) # type: ignore unknown_mot = ModelOutputThunk(value="2") diff --git a/test/stdlib_basics/test_model_output_thunk.py b/test/core/test_model_output_thunk.py similarity index 94% rename from test/stdlib_basics/test_model_output_thunk.py rename to test/core/test_model_output_thunk.py index 6f562812..562bbce9 100644 --- a/test/stdlib_basics/test_model_output_thunk.py +++ b/test/core/test_model_output_thunk.py @@ -1,8 +1,8 @@ import copy import pytest -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ModelOutputThunk +from mellea.backends import ModelOption +from mellea.core import ModelOutputThunk from mellea.stdlib.session import MelleaSession, start_session diff --git a/test/test_formatter_baseclasses.py b/test/formatters/test_template_formatter.py similarity index 96% rename from test/test_formatter_baseclasses.py rename to test/formatters/test_template_formatter.py index 2ba232a4..c29fcc4a 100644 --- a/test/test_formatter_baseclasses.py +++ b/test/formatters/test_template_formatter.py @@ -2,21 +2,15 @@ import os import sys import tempfile -from typing import Any, List, Optional +from typing import List import pytest -from mellea.backends.formatter import TemplateFormatter +from mellea.formatters import TemplateFormatter from mellea.backends.model_ids import ModelIdentifier, IBM_GRANITE_3_2_8B -from mellea.stdlib.base import ( - CBlock, - Component, - ModelOutputThunk, - TemplateRepresentation, -) -from mellea.stdlib.chat import Message -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.mobject import MObject +from mellea.core import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from mellea.stdlib.components import Message, Instruction +from mellea.stdlib.components import MObject @pytest.fixture(scope="module") @@ -245,7 +239,7 @@ def test_custom_component_external_package(tf: TemplateFormatter): Ensures template loading works for custom components defined in other packages.""" new_component_content = """ -from mellea.stdlib.base import Component, TemplateRepresentation, ModelOutputThunk +from mellea.core import Component, TemplateRepresentation, ModelOutputThunk class NewComponent(Component[str]): def parts(self): raise NotImplementedError( diff --git a/test/stdlib_basics/test_event_loop_helper.py b/test/helpers/test_event_loop_helper.py similarity index 100% rename from test/stdlib_basics/test_event_loop_helper.py rename to test/helpers/test_event_loop_helper.py diff --git a/test/stdlib_basics/test_richdocument.py b/test/stdlib/components/docs/test_richdocument.py similarity index 97% rename from test/stdlib_basics/test_richdocument.py rename to test/stdlib/components/docs/test_richdocument.py index b7c3d921..ab713078 100644 --- a/test/stdlib_basics/test_richdocument.py +++ b/test/stdlib/components/docs/test_richdocument.py @@ -1,6 +1,6 @@ import os -from mellea.stdlib.base import TemplateRepresentation -from mellea.stdlib.docs.richdocument import RichDocument, Table +from mellea.core import TemplateRepresentation +from mellea.stdlib.components.docs import RichDocument, Table import mellea from docling_core.types.doc.document import DoclingDocument import tempfile diff --git a/test/stdlib_intrinsics/test_rag/test_rag.py 
b/test/stdlib/components/intrinsic/test_rag.py similarity index 95% rename from test/stdlib_intrinsics/test_rag/test_rag.py rename to test/stdlib/components/intrinsic/test_rag.py index ea812d1c..74efe6ef 100644 --- a/test/stdlib_intrinsics/test_rag/test_rag.py +++ b/test/stdlib/components/intrinsic/test_rag.py @@ -9,9 +9,10 @@ import torch from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import ChatContext, Document -from mellea.stdlib.chat import Message -from mellea.stdlib.intrinsics import rag +from mellea.stdlib.components import Document +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message +from mellea.stdlib.components.intrinsic import rag DATA_ROOT = pathlib.Path(os.path.dirname(__file__)) / "testdata" """Location of data files for the tests in this file.""" @@ -147,12 +148,12 @@ def test_hallucination_detection(backend): # First call triggers adapter loading result = rag.flag_hallucinated_content(assistant_response, docs, context, backend) # pytest.approx() chokes on lists of records, so we do this complicated dance. - for r, e in zip(result, expected, strict=True): + for r, e in zip(result, expected, strict=True): # type: ignore assert pytest.approx(r, abs=2e-2) == e # Second call hits a different code path from the first one result = rag.flag_hallucinated_content(assistant_response, docs, context, backend) - for r, e in zip(result, expected, strict=True): + for r, e in zip(result, expected, strict=True): # type: ignore assert pytest.approx(r, abs=2e-2) == e diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/answer_relevance.json b/test/stdlib/components/intrinsic/testdata/input_json/answer_relevance.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/input_json/answer_relevance.json rename to test/stdlib/components/intrinsic/testdata/input_json/answer_relevance.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/answerability.json b/test/stdlib/components/intrinsic/testdata/input_json/answerability.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/input_json/answerability.json rename to test/stdlib/components/intrinsic/testdata/input_json/answerability.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/citations.json b/test/stdlib/components/intrinsic/testdata/input_json/citations.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/input_json/citations.json rename to test/stdlib/components/intrinsic/testdata/input_json/citations.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/context_relevance.json b/test/stdlib/components/intrinsic/testdata/input_json/context_relevance.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/input_json/context_relevance.json rename to test/stdlib/components/intrinsic/testdata/input_json/context_relevance.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/hallucination_detection.json b/test/stdlib/components/intrinsic/testdata/input_json/hallucination_detection.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/input_json/hallucination_detection.json rename to test/stdlib/components/intrinsic/testdata/input_json/hallucination_detection.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/query_rewrite.json b/test/stdlib/components/intrinsic/testdata/input_json/query_rewrite.json similarity index 100% rename from 
test/stdlib_intrinsics/test_rag/testdata/input_json/query_rewrite.json rename to test/stdlib/components/intrinsic/testdata/input_json/query_rewrite.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/output_json/citations.json b/test/stdlib/components/intrinsic/testdata/output_json/citations.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/output_json/citations.json rename to test/stdlib/components/intrinsic/testdata/output_json/citations.json diff --git a/test/stdlib_intrinsics/test_rag/testdata/output_json/hallucination_detection.json b/test/stdlib/components/intrinsic/testdata/output_json/hallucination_detection.json similarity index 100% rename from test/stdlib_intrinsics/test_rag/testdata/output_json/hallucination_detection.json rename to test/stdlib/components/intrinsic/testdata/output_json/hallucination_detection.json diff --git a/test/stdlib_basics/test_chat.py b/test/stdlib/components/test_chat.py similarity index 72% rename from test/stdlib_basics/test_chat.py rename to test/stdlib/components/test_chat.py index 819b2796..1733e8e6 100644 --- a/test/stdlib_basics/test_chat.py +++ b/test/stdlib/components/test_chat.py @@ -1,7 +1,7 @@ import pytest -from mellea.backends.openai import OpenAIBackend -from mellea.stdlib.base import Document -from mellea.stdlib.chat import Message +from mellea.stdlib.components import Document +from mellea.stdlib.components import Message +from mellea.helpers import messages_to_docs def test_message_with_docs(): @@ -11,7 +11,7 @@ def test_message_with_docs(): assert msg._docs is not None assert doc in msg._docs - docs = OpenAIBackend.messages_to_docs([msg]) + docs = messages_to_docs([msg]) assert len(docs) == 1 assert docs[0]["text"] == doc.text assert docs[0]["title"] == doc.title diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib/components/test_genslot.py similarity index 95% rename from test/stdlib_basics/test_genslot.py rename to test/stdlib/components/test_genslot.py index e7e0bfb3..97df43af 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib/components/test_genslot.py @@ -4,16 +4,17 @@ from mellea import generative, start_session from mellea.backends.model_ids import META_LLAMA_3_2_1B from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ChatContext, Context -from mellea.stdlib.genslot import ( +from mellea.core import Requirement +from mellea.stdlib.context import ChatContext, Context +from mellea.stdlib.components.genslot import ( AsyncGenerativeSlot, GenerativeSlot, PreconditionException, SyncGenerativeSlot, ) -from mellea.stdlib.requirement import Requirement, simple_validate -from mellea.stdlib.sampling.base import RejectionSamplingStrategy -from mellea.stdlib.session import MelleaSession +from mellea.stdlib.requirements import simple_validate +from mellea.stdlib.sampling import RejectionSamplingStrategy +from mellea import MelleaSession @pytest.fixture(scope="module") diff --git a/test/stdlib_basics/test_hello_world.py b/test/stdlib/components/test_hello_world.py similarity index 75% rename from test/stdlib_basics/test_hello_world.py rename to test/stdlib/components/test_hello_world.py index d6131a1d..c4d7a67a 100644 --- a/test/stdlib_basics/test_hello_world.py +++ b/test/stdlib/components/test_hello_world.py @@ -1,4 +1,4 @@ -from mellea.stdlib.instruction import Instruction +from mellea.stdlib.components import Instruction def test_empty_instr(): diff --git a/test/stdlib_basics/test_mify.py b/test/stdlib/components/test_mify.py similarity index 
96% rename from test/stdlib_basics/test_mify.py rename to test/stdlib/components/test_mify.py index 00ead004..0587811a 100644 --- a/test/stdlib_basics/test_mify.py +++ b/test/stdlib/components/test_mify.py @@ -1,9 +1,9 @@ import pytest -from mellea.backends.formatter import TemplateFormatter -from mellea.stdlib.base import Component, TemplateRepresentation -from mellea.stdlib.mobject import Query, MObjectProtocol, MObject -from mellea.stdlib.mify import mify, MifiedProtocol +from mellea.formatters import TemplateFormatter +from mellea.core import Component, TemplateRepresentation +from mellea.stdlib.components.mobject import Query, MObjectProtocol, MObject +from mellea.stdlib.components.mify import mify, MifiedProtocol def test_protocol_adherence(): diff --git a/test/stdlib_basics/test_transform.py b/test/stdlib/components/test_transform.py similarity index 90% rename from test/stdlib_basics/test_transform.py rename to test/stdlib/components/test_transform.py index 54616e60..d3a604e4 100644 --- a/test/stdlib_basics/test_transform.py +++ b/test/stdlib/components/test_transform.py @@ -1,8 +1,8 @@ import pytest -from mellea.stdlib.base import TemplateRepresentation -from mellea.stdlib.docs.richdocument import TableTransform -from mellea.stdlib.mobject import MObject, Query, Transform +from mellea.core import TemplateRepresentation +from mellea.stdlib.components.docs import TableTransform +from mellea.stdlib.components import MObject, Query, Transform custom_mobject_description = "custom mobject description" diff --git a/test/stdlib_basics/test_reqlib_markdown.py b/test/stdlib/requirements/test_reqlib_markdown.py similarity index 77% rename from test/stdlib_basics/test_reqlib_markdown.py rename to test/stdlib/requirements/test_reqlib_markdown.py index b25ac0ef..b7ab92a0 100644 --- a/test/stdlib_basics/test_reqlib_markdown.py +++ b/test/stdlib/requirements/test_reqlib_markdown.py @@ -1,12 +1,13 @@ import pytest -from mellea.stdlib.base import CBlock, ModelOutputThunk, Context, ChatContext -from mellea.stdlib.reqlib.md import ( +from mellea.core import CBlock, ModelOutputThunk +from mellea.stdlib.context import Context, ChatContext +from mellea.stdlib.requirements import ( is_markdown_list, is_markdown_table, as_markdown_list, ) -from mellea.stdlib.requirement import default_output_to_bool +from mellea.core import default_output_to_bool def from_model(s: str) -> Context: @@ -45,15 +46,15 @@ def from_model(s: str) -> Context: def test_markdown_list(): - assert is_markdown_list.validate(None, MARKDOWN_LIST_CTX) - assert len(as_markdown_list(MARKDOWN_LIST_CTX)) == 3 - assert len(as_markdown_list(MARKDOWN_OL_LIST_CTX)) == 4 - assert type(as_markdown_list(MARKDOWN_OL_LIST_CTX)[0]) is str - assert is_markdown_list.validate(None, MARKDOWN_OL_LIST_CTX) + assert is_markdown_list.validate(None, MARKDOWN_LIST_CTX) # type: ignore + assert len(as_markdown_list(MARKDOWN_LIST_CTX)) == 3 # type: ignore + assert len(as_markdown_list(MARKDOWN_OL_LIST_CTX)) == 4 # type: ignore + assert type(as_markdown_list(MARKDOWN_OL_LIST_CTX)[0]) is str # type: ignore + assert is_markdown_list.validate(None, MARKDOWN_OL_LIST_CTX) # type: ignore def test_markdown_table(): - assert is_markdown_table.validate(None, MARKDOWN_TABLE_CONTEXT) + assert is_markdown_table.validate(None, MARKDOWN_TABLE_CONTEXT) # type: ignore def test_default_output_to_bool_yes(): diff --git a/test/stdlib_basics/test_reqlib_python.py b/test/stdlib/requirements/test_reqlib_python.py similarity index 89% rename from test/stdlib_basics/test_reqlib_python.py 
rename to test/stdlib/requirements/test_reqlib_python.py index b3f9211a..403fa79e 100644 --- a/test/stdlib_basics/test_reqlib_python.py +++ b/test/stdlib/requirements/test_reqlib_python.py @@ -16,17 +16,17 @@ except ImportError: _llm_sandbox_available = False -from mellea.stdlib.base import Context -from mellea.stdlib.reqlib.python import ( +from mellea.core import Context, ModelOutputThunk +from mellea.stdlib.requirements.python_reqs import ( PythonExecutionReq, _has_python_code_listing, _python_executes_without_error, ) +from mellea.stdlib.context import ChatContext def from_model(content: str) -> Context: """Helper to create context from model output.""" - from mellea.stdlib.base import ChatContext, ModelOutputThunk ctx = ChatContext() ctx = ctx.add(ModelOutputThunk(value=content)) @@ -101,21 +101,21 @@ def test_has_python_code_listing_valid(): """Test extraction of valid Python code.""" result = _has_python_code_listing(VALID_PYTHON_CTX) assert result.as_bool() is True - assert "def hello_world" in result.reason + assert "def hello_world" in result.reason # type: ignore def test_has_python_code_listing_no_code(): """Test handling when no Python code is present.""" result = _has_python_code_listing(NO_PYTHON_CTX) assert result.as_bool() is False - assert "No Python code blocks found" in result.reason + assert "No Python code blocks found" in result.reason # type: ignore def test_has_python_code_listing_simple(): """Test extraction of simple Python code.""" result = _has_python_code_listing(PYTHON_SIMPLE_CTX) assert result.as_bool() is True - assert "print" in result.reason + assert "print" in result.reason # type: ignore # endregion @@ -161,7 +161,7 @@ def test_unsafe_execution_runtime_error(): req = PythonExecutionReq(allow_unsafe_execution=True, timeout=5) result = req.validation_fn(RUNTIME_ERROR_CTX) assert result.as_bool() is False - assert "error" in result.reason.lower() + assert "error" in result.reason.lower() # type: ignore def test_unsafe_execution_timeout(): @@ -169,7 +169,7 @@ def test_unsafe_execution_timeout(): req = PythonExecutionReq(allow_unsafe_execution=True, timeout=1) result = req.validation_fn(PYTHON_INFINITE_LOOP_CTX) assert result.as_bool() is False - assert "timed out" in result.reason.lower() + assert "timed out" in result.reason.lower() # type: ignore def test_unsafe_execution_syntax_error(): @@ -189,7 +189,7 @@ def test_import_restrictions_block_forbidden(): req = PythonExecutionReq(allow_unsafe_execution=True, allowed_imports=["os", "sys"]) result = req.validation_fn(PYTHON_WITH_FORBIDDEN_IMPORTS_CTX) assert result.as_bool() is False - assert "Unauthorized imports" in result.reason + assert "Unauthorized imports" in result.reason # type: ignore def test_import_restrictions_allow_permitted(): @@ -206,7 +206,7 @@ def test_import_restrictions_with_safe_mode(): req = PythonExecutionReq(allowed_imports=["os", "sys"]) result = req.validation_fn(PYTHON_WITH_FORBIDDEN_IMPORTS_CTX) assert result.as_bool() is False - assert "Unauthorized imports" in result.reason + assert "Unauthorized imports" in result.reason # type: ignore # endregion @@ -265,15 +265,15 @@ def test_sandbox_without_llm_sandbox_installed(): def test_description_updates_based_on_mode(): """Test that requirement description reflects execution mode.""" safe_req = PythonExecutionReq() - assert "validation only" in safe_req.description + assert "validation only" in safe_req.description # type: ignore unsafe_req = PythonExecutionReq(allow_unsafe_execution=True, timeout=5) - assert "unsafe execution" in 
unsafe_req.description - assert "timeout: 5s" in unsafe_req.description + assert "unsafe execution" in unsafe_req.description # type: ignore + assert "timeout: 5s" in unsafe_req.description # type: ignore sandbox_req = PythonExecutionReq(use_sandbox=True, timeout=10) - assert "sandbox execution" in sandbox_req.description - assert "timeout: 10s" in sandbox_req.description + assert "sandbox execution" in sandbox_req.description # type: ignore + assert "timeout: 10s" in sandbox_req.description # type: ignore def test_parameter_combinations(): @@ -322,7 +322,7 @@ def test_no_code_extraction(): req = PythonExecutionReq() result = req.validation_fn(NO_PYTHON_CTX) assert result.as_bool() is False - assert "Could not extract Python code" in result.reason + assert "Could not extract Python code" in result.reason # type: ignore # endregion diff --git a/test/stdlib_basics/test_reqlib_tools.py b/test/stdlib/requirements/test_reqlib_tools.py similarity index 78% rename from test/stdlib_basics/test_reqlib_tools.py rename to test/stdlib/requirements/test_reqlib_tools.py index 9e92d890..553cc0c9 100644 --- a/test/stdlib_basics/test_reqlib_tools.py +++ b/test/stdlib/requirements/test_reqlib_tools.py @@ -1,5 +1,5 @@ import pytest -from mellea.stdlib.reqlib.tools import _name2str +from mellea.stdlib.requirements.tool_reqs import _name2str def test_name2str(): diff --git a/test/stdlib_basics/test_requirement.py b/test/stdlib/requirements/test_requirement.py similarity index 87% rename from test/stdlib_basics/test_requirement.py rename to test/stdlib/requirements/test_requirement.py index 594fe289..5db655fe 100644 --- a/test/stdlib_basics/test_requirement.py +++ b/test/stdlib/requirements/test_requirement.py @@ -1,7 +1,7 @@ -import asyncio import pytest -from mellea.stdlib.base import ChatContext, ModelOutputThunk -from mellea.stdlib.requirement import LLMaJRequirement, Requirement, simple_validate +from mellea.stdlib.context import ChatContext +from mellea.core import ModelOutputThunk, Requirement +from mellea.stdlib.requirements import LLMaJRequirement, simple_validate from mellea.stdlib.session import start_session ctx = ChatContext() @@ -51,7 +51,7 @@ def test_simple_validate_bool_string(): def test_simple_validate_invalid(): - validation_func = simple_validate(lambda x: None) + validation_func = simple_validate(lambda x: None) # type: ignore with pytest.raises(ValueError): val_result = validation_func(ctx) diff --git a/test/stdlib_basics/test_majority_voting.py b/test/stdlib/sampling/test_majority_voting.py similarity index 94% rename from test/stdlib_basics/test_majority_voting.py rename to test/stdlib/sampling/test_majority_voting.py index 56cc3389..7bd9b9bd 100644 --- a/test/stdlib_basics/test_majority_voting.py +++ b/test/stdlib/sampling/test_majority_voting.py @@ -1,13 +1,13 @@ from mellea.backends import ModelOption from mellea import start_session, MelleaSession -from mellea.stdlib.requirement import check, req, simple_validate +from mellea.stdlib.requirements import check, req, simple_validate from mellea.stdlib.sampling.majority_voting import ( MBRDRougeLStrategy, MajorityVotingStrategyForMath, ) import pytest -from mellea.stdlib.sampling.types import SamplingResult +from mellea.core import SamplingResult @pytest.fixture(scope="module") diff --git a/test/stdlib_basics/test_sampling_ctx.py b/test/stdlib/sampling/test_sampling_ctx.py similarity index 89% rename from test/stdlib_basics/test_sampling_ctx.py rename to test/stdlib/sampling/test_sampling_ctx.py index 362730d6..c64b23d4 100644 --- 
a/test/stdlib_basics/test_sampling_ctx.py +++ b/test/stdlib/sampling/test_sampling_ctx.py @@ -1,13 +1,9 @@ import pytest from mellea import start_session from mellea.backends import ModelOption -from mellea.stdlib.base import ChatContext, ModelOutputThunk, Context -from mellea.stdlib.requirement import Requirement -from mellea.stdlib.sampling import ( - MultiTurnStrategy, - RejectionSamplingStrategy, - SamplingResult, -) +from mellea.core import ModelOutputThunk, Context, Requirement, SamplingResult +from mellea.stdlib.context import ChatContext +from mellea.stdlib.sampling import MultiTurnStrategy, RejectionSamplingStrategy class TestSamplingCtxCase: @@ -46,7 +42,7 @@ def test_ctx_for_rejection_sampling(self): assert len(self.m.ctx.as_list()) == 2, ( "there should only be a message and a response in the ctx." ) - assert len(self.m.last_prompt()) == 1, ( + assert len(self.m.last_prompt()) == 1, ( # type: ignore "Last prompt should only have only one instruction inside - independent of sampling iterations." ) @@ -55,7 +51,7 @@ def test_ctx_for_rejection_sampling(self): # the correct actions / results in it. assert isinstance(val_res.context, Context) assert isinstance(val_res.thunk, ModelOutputThunk) - assert isinstance(val_res.context.previous_node.node_data, Requirement) + assert isinstance(val_res.context.previous_node.node_data, Requirement) # type: ignore assert val_res.context.node_data is val_res.thunk def test_ctx_for_multiturn(self): @@ -75,7 +71,7 @@ def test_ctx_for_multiturn(self): assert len(self.m.ctx.as_list()) >= 2, ( "there should be at least a message and a response in the ctx; more if the first result failed validation" ) - assert len(self.m.last_prompt()) == len(res.sample_generations) * 2 - 1, ( + assert len(self.m.last_prompt()) == len(res.sample_generations) * 2 - 1, ( # type: ignore "For n sampling iterations there should be 2n-1 prompt conversation elements in the last prompt." 
) diff --git a/test/stdlib_basics/test_think_budget_forcing.py b/test/stdlib/sampling/test_think_budget_forcing.py similarity index 93% rename from test/stdlib_basics/test_think_budget_forcing.py rename to test/stdlib/sampling/test_think_budget_forcing.py index ee53a231..747849f1 100644 --- a/test/stdlib_basics/test_think_budget_forcing.py +++ b/test/stdlib/sampling/test_think_budget_forcing.py @@ -5,7 +5,7 @@ from mellea import MelleaSession, start_session from mellea.backends import ModelOption from mellea.backends.model_ids import OPENAI_GPT_OSS_20B -from mellea.stdlib.base import CBlock +from mellea.core import CBlock from mellea.stdlib.sampling.budget_forcing import BudgetForcingSamplingStrategy MODEL_ID = OPENAI_GPT_OSS_20B @@ -45,7 +45,7 @@ def test_think_big(m_session: MelleaSession, gh_run: int): answer_suffix="The final answer is:", requirements=None, ) - result = m_session.instruct(action, strategy=strategy) + result = m_session.instruct(action, strategy=strategy) # type: ignore print("\n******\nThink big:") print(str(result)) @@ -70,7 +70,7 @@ def test_think_little(m_session: MelleaSession, gh_run: int): answer_suffix="The final answer is: \\boxed{", requirements=None, ) - result = m_session.instruct(action, strategy=strategy) + result = m_session.instruct(action, strategy=strategy) # type: ignore print("\n******\nThink little:") print(str(result)) diff --git a/test/stdlib_basics/test_base_context.py b/test/stdlib/test_base_context.py similarity index 90% rename from test/stdlib_basics/test_base_context.py rename to test/stdlib/test_base_context.py index 0c1a5620..698a9240 100644 --- a/test/stdlib_basics/test_base_context.py +++ b/test/stdlib/test_base_context.py @@ -1,6 +1,7 @@ import pytest -from mellea.stdlib.base import Context, CBlock, SimpleContext, ChatContext +from mellea.core import Context, CBlock +from mellea.stdlib.context import SimpleContext, ChatContext def context_construction(cls: type[Context]): @@ -38,7 +39,7 @@ def test_render_view_for_simple_context(): for i in range(5): ctx = ctx.add(CBlock(f"a {i}")) assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" - assert len(ctx.view_for_generation()) == 0, ( + assert len(ctx.view_for_generation()) == 0, ( # type: ignore "Render size should be 0 -- NO HISTORY for SimpleContext" ) @@ -48,7 +49,7 @@ def test_render_view_for_chat_context(): for i in range(5): ctx = ctx.add(CBlock(f"a {i}")) assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" - assert len(ctx.view_for_generation()) == 3, "Render size should be 3" + assert len(ctx.view_for_generation()) == 3, "Render size should be 3" # type: ignore def test_actions_for_available_tools(): diff --git a/test/stdlib_basics/test_chat_view.py b/test/stdlib/test_chat_view.py similarity index 91% rename from test/stdlib_basics/test_chat_view.py rename to test/stdlib/test_chat_view.py index cbc9c7a2..9b0ff93d 100644 --- a/test/stdlib_basics/test_chat_view.py +++ b/test/stdlib/test_chat_view.py @@ -1,7 +1,7 @@ import pytest -from mellea.stdlib.base import ChatContext, ModelOutputThunk -from mellea.stdlib.chat import Message, as_chat_history +from mellea.stdlib.context import ChatContext +from mellea.stdlib.components import Message, as_chat_history from mellea.stdlib.session import start_session diff --git a/test/stdlib_basics/test_functional.py b/test/stdlib/test_functional.py similarity index 90% rename from test/stdlib_basics/test_functional.py rename to test/stdlib/test_functional.py index 4dbfb9e0..95f8add5 
100644 --- a/test/stdlib_basics/test_functional.py +++ b/test/stdlib/test_functional.py @@ -1,10 +1,10 @@ import pytest -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.core import ModelOutputThunk +from mellea.stdlib.components import Message from mellea.stdlib.functional import instruct, aact, avalidate, ainstruct -from mellea.stdlib.requirement import req +from mellea.stdlib.requirements import req from mellea.stdlib.session import start_session diff --git a/test/stdlib_basics/test_session.py b/test/stdlib/test_session.py similarity index 89% rename from test/stdlib_basics/test_session.py rename to test/stdlib/test_session.py index 6694246c..ab644877 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib/test_session.py @@ -3,10 +3,10 @@ import pytest -from mellea.backends.ollama import OllamaModelBackend -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ChatContext, ModelOutputThunk -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.stdlib.context import ChatContext +from mellea.core import ModelOutputThunk +from mellea.stdlib.components import Message from mellea.stdlib.session import start_session, MelleaSession @@ -88,11 +88,11 @@ async def test_async_await_with_chat_context(m_session): ctx = m_session.ctx for i in range(len(history)): - assert ctx.node_data is history[i] - ctx = ctx.previous_node + assert ctx.node_data is history[i] # type: ignore + ctx = ctx.previous_node # type: ignore # Ensure we made it back to the root. - assert ctx.is_root_node == True + assert ctx.is_root_node == True # type: ignore async def test_async_without_waiting_with_chat_context(m_session): @@ -105,7 +105,7 @@ async def test_async_without_waiting_with_chat_context(m_session): _, _ = await asyncio.gather(co2, co1) ctx = m_session.ctx - assert len(ctx.view_for_generation()) == 2 + assert len(ctx.view_for_generation()) == 2 # type: ignore def test_session_copy_with_context_ops(m_session): @@ -135,7 +135,7 @@ def test_session_copy_with_context_ops(m_session): class TestPowerup: - def hello(m: MelleaSession): + def hello(m: MelleaSession): # type: ignore return "hello" diff --git a/test/stdlib_basics/test_spans.py b/test/stdlib/test_spans.py similarity index 86% rename from test/stdlib_basics/test_spans.py rename to test/stdlib/test_spans.py index 1934dd47..0ddb4bc1 100644 --- a/test/stdlib_basics/test_spans.py +++ b/test/stdlib/test_spans.py @@ -1,12 +1,8 @@ -import asyncio -import os - import pytest -from mellea.backends.ollama import OllamaModelBackend -from mellea.backends.types import ModelOption -from mellea.stdlib.base import ChatContext, ModelOutputThunk, CBlock, SimpleComponent -from mellea.stdlib.chat import Message +from mellea.backends import ModelOption +from mellea.core import CBlock +from mellea.stdlib.components import SimpleComponent from mellea.stdlib.session import start_session, MelleaSession from mellea.backends.model_ids import IBM_GRANITE_3_3_8B from mellea.backends.huggingface import LocalHFBackend @@ -43,7 +39,7 @@ async def test_lazy_spans(m_session): @pytest.mark.qualitative async def test_kv(m_session): m: MelleaSession = m_session - backend, ctx = m.backend, m.ctx + backend, ctx = m.backend, m.ctx # type: ignore ctx = ctx.add( SimpleComponent( diff --git a/test/stdlib_basics/test_contextual_session.py b/test/stdlib_basics/test_contextual_session.py deleted 
file mode 100644 index a401f117..00000000 --- a/test/stdlib_basics/test_contextual_session.py +++ /dev/null @@ -1,227 +0,0 @@ -# TODO: needs to be rewritten -# from typing import Literal -# -# import pytest -# -# from mellea import chat, generative, instruct, query, start_session, transform, validate -# from mellea.backends.model_ids import IBM_GRANITE_3_3_8B, META_LLAMA_3_2_1B -# from mellea.stdlib.base import ModelOutputThunk -# from mellea.stdlib.mify import MifiedProtocol, mify -# from mellea.stdlib.requirement import req -# from mellea.stdlib.session import MelleaSession, get_session -# -# -# @pytest.fixture(scope="module") -# def model_id(gh_run: int): -# if gh_run == 1: -# return META_LLAMA_3_2_1B -# else: -# return IBM_GRANITE_3_3_8B -# -# -# @generative -# def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... -# -# -# @generative -# def generate_summary(text: str) -> str: ... -# -# -# @mify(fields_include={"name", "age"}) -# class TestPerson: -# def __init__(self, name: str, age: int): -# self.name = name -# self.age = age -# -# def get_info(self) -> str: -# """Get person information.""" -# return f"{self.name} is {self.age} years old" -# -# -# def test_basic_contextual_session(model_id): -# """Test basic contextual session usage with convenience functions.""" -# with start_session(model_id=model_id): -# # Test instruct -# result = instruct("Say hello") -# assert isinstance(result, ModelOutputThunk) -# assert result.value is not None -# -# # Test that we can get the session -# current_session = get_session() -# assert isinstance(current_session, MelleaSession) -# -# -# def test_no_active_session_error(): -# """Test error handling when no session is active.""" -# with pytest.raises(RuntimeError, match="No active session found"): -# get_session() -# -# with pytest.raises(RuntimeError, match="No active session found"): -# instruct("test") -# -# with pytest.raises(RuntimeError, match="No active session found"): -# chat("test") -# -# @pytest.mark.qualitative -# def test_generative_with_contextual_session(model_id): -# """Test generative slots work with contextual sessions.""" -# with start_session(model_id=model_id): -# # Test without explicit session parameter -# result = classify_sentiment(text="I love this!") -# assert result in ["positive", "negative"] -# -# # Test with summary generation -# summary = generate_summary(text="A short text about something interesting.") -# assert isinstance(summary, str) -# assert len(summary) > 0 -# -# @pytest.mark.qualitative -# def test_generative_backward_compatibility(model_id): -# """Test that generative slots still work with explicit session parameter.""" -# with start_session(model_id=model_id) as m: -# # Test old pattern still works -# result = classify_sentiment(m, text="I love this!") -# assert result in ["positive", "negative"] -# -# -# def test_mify_with_contextual_session(model_id): -# """Test mify functionality with contextual sessions.""" -# with start_session(model_id=model_id): -# person = TestPerson("Alice", 30) -# assert isinstance(person, MifiedProtocol) -# -# # Test query functionality -# query_result = query(person, "What is this person's name?") -# assert isinstance(query_result, ModelOutputThunk) -# -# # Test transform functionality -# transform_result = transform(person, "Make this person 5 years older") -# # Transform can return either ModelOutputThunk or the tool output when tools are called -# assert transform_result is not None -# -# -# def test_nested_sessions(model_id): -# """Test nested sessions 
behavior.""" -# with start_session(model_id=model_id) as outer_session: -# outer_result = instruct("outer session test") -# assert isinstance(outer_result, ModelOutputThunk) -# -# with start_session(model_id=model_id) as inner_session: -# # Inner session should be active -# current_session = get_session() -# assert current_session is inner_session -# -# inner_result = instruct("inner session test") -# assert isinstance(inner_result, ModelOutputThunk) -# -# # After inner session exits, outer should be active again -# current_session = get_session() -# assert current_session is outer_session -# -# -# def test_session_cleanup(model_id): -# """Test session cleanup after context exit.""" -# session_ref = None -# with start_session(model_id=model_id) as m: -# session_ref = m -# instruct("test during session") -# -# # After exiting context, no session should be active -# with pytest.raises(RuntimeError, match="No active session found"): -# get_session() -# -# # Session should have been cleaned up -# assert hasattr(session_ref, "ctx") -# -# -# def test_all_convenience_functions(model_id): -# """Test all convenience functions work within contextual session.""" -# with start_session(model_id=model_id): -# # Test instruct -# instruct_result = instruct("Generate a greeting") -# assert isinstance(instruct_result, ModelOutputThunk) -# -# # Test chat -# chat_result = chat("Hello there") -# assert hasattr(chat_result, "content") -# -# # Test validate -# validation = validate([req("The response should be positive")]) -# assert validation is not None -# -# # Test query with a mified object -# test_person = TestPerson("Test", 42) -# query_result = query(test_person, "What is the name?") -# assert isinstance(query_result, ModelOutputThunk) -# -# # Test transform with a mified object -# transform_result = transform(test_person, "Double the age") -# assert transform_result is not None -# -# -# def test_session_with_parameters(model_id): -# """Test contextual session with custom parameters.""" -# with start_session(backend_name="ollama", model_id=model_id) as m: -# result = instruct("test with parameters") -# assert isinstance(result, ModelOutputThunk) -# assert isinstance(m, MelleaSession) -# -# -# def test_multiple_sequential_sessions(model_id): -# """Test multiple sequential contextual sessions.""" -# # First session -# with start_session(model_id=model_id): -# result1 = instruct("first session") -# assert isinstance(result1, ModelOutputThunk) -# -# # Ensure no session is active between contexts -# with pytest.raises(RuntimeError, match="No active session found"): -# get_session() -# -# # Second session -# with start_session(model_id=model_id): -# result2 = instruct("second session") -# assert isinstance(result2, ModelOutputThunk) -# -# -# def test_contextual_session_with_mified_object_methods(model_id): -# """Test that mified objects work properly within contextual sessions.""" -# with start_session(model_id=model_id): -# person = TestPerson("Bob", 25) -# -# # Test that mified object methods work -# query_obj = person.get_query_object("What's the age?") -# assert query_obj is not None -# -# transform_obj = person.get_transform_object("Make older") -# assert transform_obj is not None -# -# # Test format_for_llm -# llm_format = person.format_for_llm() -# assert llm_format is not None -# assert hasattr(llm_format, "args") -# -# -# def test_session_methods_with_mified_objects(model_id): -# """Test using session query/transform methods with mified objects.""" -# with start_session(model_id=model_id) as m: -# 
person = TestPerson("Charlie", 35) -# -# # Test session query method -# query_result = m.query(person, "What is this person's age?") -# assert isinstance(query_result, ModelOutputThunk) -# -# # Test session transform method -# transform_result = m.transform(person, "Make this person younger") -# # Transform can return either ModelOutputThunk or tool output when tools are called -# assert transform_result is not None -# -# # Verify mified objects have query/transform object creation methods -# assert hasattr(person, "get_query_object") -# assert hasattr(person, "get_transform_object") -# assert hasattr(person, "_query_type") -# assert hasattr(person, "_transform_type") -# -# -# if __name__ == "__main__": -# pytest.main([__file__])