diff --git a/pr.md b/pr.md new file mode 100644 index 0000000000..3be3301f80 --- /dev/null +++ b/pr.md @@ -0,0 +1,40 @@ +# PR TODOs + +- [ ] Decide on the feedback schema that is appropriate for actions. There are concepts that seem to overlap (ex. config ~ self, score ~ action, etc...) +- [ ] Firm up the Action Spec (seems pretty good IMO) +- [ ] Python API for creating objects is pretty bad - especially when we want to reference other objects... this is not clean right now. +- [ ] Create the concept of a filter action (needs to have "enabled" attribute) +- [ ] UI Elements + - [ ] Configured Actions + - [ ] List + - [ ] Create + - [ ] Edit + - [ ] Delete (delete objects) + - [ ] View? + - [ ] See Mappings + - [ ] Mappings + - [ ] List + - [ ] Create + - [ ] Edit + - [ ] Delete (delete objects) + - [ ] View? + - [ ] Link to configured action + - [ ] See Actioned Calls (listing of feedback) + - [ ] Filter Action + - [ ] List + - [ ] Create + - [ ] Edit + - [ ] Disable / Pause + - [ ] Delete (delete objects) + - [ ] View? 
+ - [ ] Link to mapping + - [ ] See "live feed" of applicable calls + - [ ] Call Table + - [ ] Action Result Column(s) + - [ ] "Fill" Button (or create filter action - basically a single or live version) + - [ ] Call View + - [ ] Action Results + - [ ] Single Execution Button (would be nice to have smart mapping) + - [ ] OpVersion View + - [ ] View associated mappings +- [ ] Create additional \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index bbba381f04..010dac5582 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,6 +477,7 @@ def create_client(request) -> weave_init.InitializedClient: server: tsi.TraceServerInterface entity = "shawn" project = "test-project" + weave_server_flag = "clickhouse" if weave_server_flag == "sqlite": sqlite_server = sqlite_trace_server.SqliteTraceServer( "file::memory:?cache=shared" @@ -500,7 +501,7 @@ def create_client(request) -> weave_init.InitializedClient: ) server = remote_server elif weave_server_flag == ("prod"): - inited_client = weave_init.init_weave("dev_testing") + inited_client = weave_init.init_weave("dev_testing_evals_3") if inited_client is None: client = TestOnlyFlushingWeaveClient( diff --git a/tests/trace/demo.ipynb b/tests/trace/demo.ipynb new file mode 100644 index 0000000000..2c199d749a --- /dev/null +++ b/tests/trace/demo.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged in as Weights & Biases user: timssweeney.\n", + "View Weave data at https://wandb.ai/timssweeney/action_test_4/weave\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "import weave\n", + "\n", + "os.environ[\"WF_TRACE_SERVER_URL\"] = \"http://127.0.01:6345\"\n", + "\n", + "client = weave.init(\"action_test_4\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from weave.trace import 
autopatch, weave_init\n", + "# from weave.trace_server import clickhouse_trace_server_batched\n", + "\n", + "# ch_server = clickhouse_trace_server_batched.ClickHouseTraceServer.from_env()\n", + "# inited_client = weave_init.InitializedClient(client)\n", + "# autopatch.autopatch()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🍩 https://wandb.ai/timssweeney/action_test_4/r/call/0192b29c-96ba-7a32-8228-74a274b25a69\n", + "🍩 https://wandb.ai/timssweeney/action_test_4/r/call/0192b29c-96bd-78f2-8966-473c36799e10\n", + "🍩 https://wandb.ai/timssweeney/action_test_4/r/call/0192b29c-9827-71f0-9997-199922a0b5a5\n", + "🍩 https://wandb.ai/timssweeney/action_test_4/r/call/0192b29c-999b-7762-93a5-f70678bba80b\n", + "🍩 https://wandb.ai/timssweeney/action_test_4/r/call/0192b29c-9ae2-70c1-a03b-ee5f915e48fe\n", + "[Call(_op_name=, trace_id='0192b29c-96ba-7a32-8228-74b76ac97678', project_id='timssweeney/action_test_4', parent_id=None, inputs={'user_input': 'My name is Tim.'}, id='0192b29c-96ba-7a32-8228-74a274b25a69', output=\"i don't know!\", exception=None, summary={}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 19, 642740, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[], _feedback=None), Call(_op_name=, trace_id='0192b29c-96bd-78f2-8966-47406b46d0d5', project_id='timssweeney/action_test_4', parent_id=None, inputs={'user_input': 'My name is Scott.'}, id='0192b29c-96bd-78f2-8966-473c36799e10', output='Scott', exception=None, summary={'usage': {'gpt-3.5-turbo-0125': 
{'requests': 1, 'completion_tokens': 1, 'prompt_tokens': 35, 'total_tokens': 36, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 20, 6862, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[Call(_op_name=, trace_id='0192b29c-96bd-78f2-8966-47406b46d0d5', project_id='timssweeney/action_test_4', parent_id='0192b29c-96bd-78f2-8966-473c36799e10', inputs={'self': , 'messages': [{'role': 'system', 'content': 'Extract the name from the user input. If there is no name, return an empty string.'}, {'role': 'user', 'content': 'My name is Scott.'}], 'model': 'gpt-3.5-turbo', 'max_tokens': 64, 'temperature': 0.0, 'top_p': 1}, id='0192b29c-96be-7f13-8fe1-0e33db96732c', output=ObjectRecord({'id': 'chatcmpl-AL1M4ja83J341nQzcFOlzgmbEd4Bt', 'choices': [ObjectRecord({'finish_reason': 'stop', 'index': 0, 'logprobs': None, 'message': ObjectRecord({'content': 'Scott', 'refusal': None, 'role': 'assistant', 'audio': None, 'function_call': None, 'tool_calls': None, '_class_name': 'ChatCompletionMessage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'Choice', '_bases': ['BaseModel', 'BaseModel']})], 'created': 1729573460, 'model': 'gpt-3.5-turbo-0125', 'object': 'chat.completion', 'service_tier': None, 'system_fingerprint': None, 'usage': ObjectRecord({'completion_tokens': 1, 'prompt_tokens': 35, 'total_tokens': 36, 'completion_tokens_details': ObjectRecord({'audio_tokens': None, 'reasoning_tokens': 0, '_class_name': 'CompletionTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), 
'prompt_tokens_details': ObjectRecord({'audio_tokens': None, 'cached_tokens': 0, '_class_name': 'PromptTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'CompletionUsage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'ChatCompletion', '_bases': ['BaseModel', 'BaseModel']}), exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 1, 'prompt_tokens': 35, 'total_tokens': 36, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 19, 982630, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[], _feedback=None)], _feedback=None), Call(_op_name=, trace_id='0192b29c-9827-71f0-9997-19acc42b1fb4', project_id='timssweeney/action_test_4', parent_id=None, inputs={'user_input': 'My name is Adrian.'}, id='0192b29c-9827-71f0-9997-199922a0b5a5', output='Adrian', exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 2, 'prompt_tokens': 35, 'total_tokens': 37, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 
20, 378668, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[Call(_op_name=, trace_id='0192b29c-9827-71f0-9997-19acc42b1fb4', project_id='timssweeney/action_test_4', parent_id='0192b29c-9827-71f0-9997-199922a0b5a5', inputs={'self': , 'messages': [{'role': 'system', 'content': 'Extract the name from the user input. If there is no name, return an empty string.'}, {'role': 'user', 'content': 'My name is Adrian.'}], 'model': 'gpt-3.5-turbo', 'max_tokens': 64, 'temperature': 0.0, 'top_p': 1}, id='0192b29c-9828-7f50-b3f7-31a4726f4346', output=ObjectRecord({'id': 'chatcmpl-AL1M4cVHoiiAY1b9FElflF3EFwpNh', 'choices': [ObjectRecord({'finish_reason': 'stop', 'index': 0, 'logprobs': None, 'message': ObjectRecord({'content': 'Adrian', 'refusal': None, 'role': 'assistant', 'audio': None, 'function_call': None, 'tool_calls': None, '_class_name': 'ChatCompletionMessage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'Choice', '_bases': ['BaseModel', 'BaseModel']})], 'created': 1729573460, 'model': 'gpt-3.5-turbo-0125', 'object': 'chat.completion', 'service_tier': None, 'system_fingerprint': None, 'usage': ObjectRecord({'completion_tokens': 2, 'prompt_tokens': 35, 'total_tokens': 37, 'completion_tokens_details': ObjectRecord({'audio_tokens': None, 'reasoning_tokens': 0, '_class_name': 'CompletionTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), 'prompt_tokens_details': ObjectRecord({'audio_tokens': None, 'cached_tokens': 0, '_class_name': 'PromptTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'CompletionUsage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'ChatCompletion', '_bases': ['BaseModel', 'BaseModel']}), exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 2, 'prompt_tokens': 35, 'total_tokens': 37, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': 
'0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 20, 352565, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[], _feedback=None)], _feedback=None), Call(_op_name=, trace_id='0192b29c-999b-7762-93a5-f710285fa650', project_id='timssweeney/action_test_4', parent_id=None, inputs={'user_input': 'My name is Jeff.'}, id='0192b29c-999b-7762-93a5-f70678bba80b', output='Jeff', exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 1, 'prompt_tokens': 35, 'total_tokens': 36, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 20, 705384, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[Call(_op_name=, trace_id='0192b29c-999b-7762-93a5-f710285fa650', project_id='timssweeney/action_test_4', parent_id='0192b29c-999b-7762-93a5-f70678bba80b', inputs={'self': , 'messages': [{'role': 'system', 'content': 'Extract the name from the user input. 
If there is no name, return an empty string.'}, {'role': 'user', 'content': 'My name is Jeff.'}], 'model': 'gpt-3.5-turbo', 'max_tokens': 64, 'temperature': 0.0, 'top_p': 1}, id='0192b29c-999c-7682-89e1-df5a56d38f7b', output=ObjectRecord({'id': 'chatcmpl-AL1M48W35ElRrB3HOkUGwYTUOByje', 'choices': [ObjectRecord({'finish_reason': 'stop', 'index': 0, 'logprobs': None, 'message': ObjectRecord({'content': 'Jeff', 'refusal': None, 'role': 'assistant', 'audio': None, 'function_call': None, 'tool_calls': None, '_class_name': 'ChatCompletionMessage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'Choice', '_bases': ['BaseModel', 'BaseModel']})], 'created': 1729573460, 'model': 'gpt-3.5-turbo-0125', 'object': 'chat.completion', 'service_tier': None, 'system_fingerprint': None, 'usage': ObjectRecord({'completion_tokens': 1, 'prompt_tokens': 35, 'total_tokens': 36, 'completion_tokens_details': ObjectRecord({'audio_tokens': None, 'reasoning_tokens': 0, '_class_name': 'CompletionTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), 'prompt_tokens_details': ObjectRecord({'audio_tokens': None, 'cached_tokens': 0, '_class_name': 'PromptTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'CompletionUsage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'ChatCompletion', '_bases': ['BaseModel', 'BaseModel']}), exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 1, 'prompt_tokens': 35, 'total_tokens': 36, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, 
ended_at=datetime.datetime(2024, 10, 22, 5, 4, 20, 684769, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[], _feedback=None)], _feedback=None), Call(_op_name=, trace_id='0192b29c-9ae2-70c1-a03b-ee61a00a9be4', project_id='timssweeney/action_test_4', parent_id=None, inputs={'user_input': 'My name is Shawn.'}, id='0192b29c-9ae2-70c1-a03b-ee5f915e48fe', output='Shawn', exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 2, 'prompt_tokens': 35, 'total_tokens': 37, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, ended_at=datetime.datetime(2024, 10, 22, 5, 4, 21, 123765, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[Call(_op_name=, trace_id='0192b29c-9ae2-70c1-a03b-ee61a00a9be4', project_id='timssweeney/action_test_4', parent_id='0192b29c-9ae2-70c1-a03b-ee5f915e48fe', inputs={'self': , 'messages': [{'role': 'system', 'content': 'Extract the name from the user input. 
If there is no name, return an empty string.'}, {'role': 'user', 'content': 'My name is Shawn.'}], 'model': 'gpt-3.5-turbo', 'max_tokens': 64, 'temperature': 0.0, 'top_p': 1}, id='0192b29c-9ae2-70c1-a03b-ee7deb9b5c66', output=ObjectRecord({'id': 'chatcmpl-AL1M5JnVQ8K82EcgaghB8xkcpmeZY', 'choices': [ObjectRecord({'finish_reason': 'stop', 'index': 0, 'logprobs': None, 'message': ObjectRecord({'content': 'Shawn', 'refusal': None, 'role': 'assistant', 'audio': None, 'function_call': None, 'tool_calls': None, '_class_name': 'ChatCompletionMessage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'Choice', '_bases': ['BaseModel', 'BaseModel']})], 'created': 1729573461, 'model': 'gpt-3.5-turbo-0125', 'object': 'chat.completion', 'service_tier': None, 'system_fingerprint': None, 'usage': ObjectRecord({'completion_tokens': 2, 'prompt_tokens': 35, 'total_tokens': 37, 'completion_tokens_details': ObjectRecord({'audio_tokens': None, 'reasoning_tokens': 0, '_class_name': 'CompletionTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), 'prompt_tokens_details': ObjectRecord({'audio_tokens': None, 'cached_tokens': 0, '_class_name': 'PromptTokensDetails', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'CompletionUsage', '_bases': ['BaseModel', 'BaseModel']}), '_class_name': 'ChatCompletion', '_bases': ['BaseModel', 'BaseModel']}), exception=None, summary={'usage': {'gpt-3.5-turbo-0125': {'requests': 1, 'completion_tokens': 2, 'prompt_tokens': 35, 'total_tokens': 37, 'completion_tokens_details': {'reasoning_tokens': 0}, 'prompt_tokens_details': {'cached_tokens': 0}}}}, display_name=None, attributes=AttributesDict({'weave': {'client_version': '0.51.18-dev0', 'source': 'python-sdk', 'os_name': 'Darwin', 'os_version': 'Darwin Kernel Version 23.2.0: Wed Nov 15 21:53:18 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6000', 'os_release': '23.2.0', 'sys_version': '3.10.8 (main, Dec 5 2022, 18:10:41) [Clang 14.0.0 (clang-1400.0.29.202)]'}}), started_at=None, 
ended_at=datetime.datetime(2024, 10, 22, 5, 4, 21, 96048, tzinfo=datetime.timezone.utc), deleted_at=None, _children=[], _feedback=None)], _feedback=None)]\n" + ] + } + ], + "source": [ + "from openai import OpenAI\n", + "\n", + "openai_client = OpenAI()\n", + "\n", + "\n", + "@weave.op\n", + "def extract_name(user_input: str) -> str:\n", + " if \"Tim\" in user_input:\n", + " return \"i don't know!\"\n", + " response = openai_client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"Extract the name from the user input. If there is no name, return an empty string.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": user_input},\n", + " ],\n", + " temperature=0.0,\n", + " max_tokens=64,\n", + " top_p=1,\n", + " )\n", + " return response.choices[0].message.content\n", + "\n", + "\n", + "calls = []\n", + "names = [\"Tim\", \"Scott\", \"Adrian\", \"Jeff\", \"Shawn\"]\n", + "for name in names:\n", + " res, call = extract_name.call(f\"My name is {name}.\")\n", + " calls.append(call)\n", + "\n", + "print(calls)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from weave.collection_objects import action_objects\n", + "from weave.trace.weave_client import get_ref\n", + "from weave.trace_server.interface.collections import action_collection\n", + "\n", + "action = action_objects.ActionWithConfig(\n", + " name=\"is_name_extracted\",\n", + " action=action_collection._BuiltinAction(\n", + " name=\"openai_completion\",\n", + " ),\n", + " config={\n", + " \"model\": \"gpt-4o-mini\",\n", + " \"system_prompt\": \"Given the following prompt and response, determine if the name was extracted correctly.\",\n", + " \"response_format\": {\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": {\n", + " \"name\": \"is_name_extracted\",\n", + " \"schema\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\"is_extracted\": 
{\"type\": \"boolean\"}},\n", + " \"required\": [\"is_extracted\"],\n", + " \"additionalProperties\": False,\n", + " },\n", + " \"strict\": True,\n", + " },\n", + " },\n", + " },\n", + ")\n", + "mapping = action_objects.ActionOpMapping(\n", + " action=action,\n", + " op_name=get_ref(extract_name).name,\n", + " op_digest=get_ref(extract_name).digest,\n", + " input_mapping={\n", + " \"prompt\": \"inputs.user_input\",\n", + " \"response\": \"output\",\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from weave.trace_server import trace_server_interface as tsi\n", + "\n", + "res = client.server.execute_batch_action(\n", + " req=tsi.ExecuteBatchActionReq(\n", + " project_id=client._project_id(), call_ids=[c.id for c in calls], mapping=mapping\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "wandb-weave", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/trace/test_actions.py b/tests/trace/test_actions.py new file mode 100644 index 0000000000..6537fda5b2 --- /dev/null +++ b/tests/trace/test_actions.py @@ -0,0 +1,154 @@ +import os + +from openai import OpenAI + +import weave +from weave.collection_objects import action_objects +from weave.trace.weave_client import WeaveClient, get_ref +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.interface.collections import action_collection +from weave.trace_server.interface.feedback_types.action_feedback_type import ( + ACTION_FEEDBACK_TYPE_NAME, +) + + 
+def test_action_create(client: WeaveClient): + api_key = os.environ.get("OPENAI_API_KEY", "DUMMY_API_KEY") + + openai_client = OpenAI(api_key=api_key) + + @weave.op + def extract_name(user_input: str) -> str: + response = openai_client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": "Extract the name from the user input. If there is no name, return an empty string.", + }, + {"role": "user", "content": user_input}, + ], + temperature=0.0, + max_tokens=64, + top_p=1, + ) + return response.choices[0].message.content + + res, call = extract_name.call("My name is Tim.") + + action = action_objects.ActionWithConfigObject( + name="is_name_extracted", + action=action_collection._BuiltinAction( + name="openai_completion", + ), + config={ + "model": "gpt-4o-mini", + "system_prompt": "Given the following prompt and response, determine if the name was extracted correctly.", + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "is_name_extracted", + "schema": { + "type": "object", + "properties": {"is_extracted": {"type": "boolean"}}, + "required": ["is_extracted"], + "additionalProperties": False, + }, + "strict": True, + }, + }, + }, + ) + + mapping = action_objects.ActionOpMappingObject( + name="extract_name-is_name_extracted", + action=action, + op_name=get_ref(extract_name).name, + op_digest=get_ref(extract_name).digest, + input_mapping={ + "prompt": "inputs.user_input", + "response": "output", + }, + ) + req = tsi.ExecuteBatchActionReq( + project_id=client._project_id(), call_ids=[call.id], mapping=mapping + ) + + res = client.server.execute_batch_action(req=req) + + # AFTER CALL! 
+ weave.publish(mapping) + + gotten_call = client.server.calls_query( + req=tsi.CallsQueryReq( + project_id=client._project_id(), call_ids=[call.id], include_feedback=True + ) + ) + assert len(gotten_call.calls) == 2 + target_call = gotten_call.calls[0] + + assert target_call.op_name == get_ref(extract_name).uri() + feedbacks = target_call.summary["weave"]["feedback"] + assert len(feedbacks) == 1 + feedback = feedbacks[0] + assert feedback["feedback_type"] == ACTION_FEEDBACK_TYPE_NAME + assert feedback["payload"]["name"] == "is_name_extracted" + assert feedback["payload"]["action_mapping_ref"] == get_ref(mapping).uri() + assert feedback["payload"]["results"] == {"is_extracted": True} + + +# def test_builtin_actions(client: WeaveClient): +# actions = client.server.actions_list() +# assert len(actions) > 0 + + +# def test_action_flow(client: WeaveClient): +# # 1. Bootstrap builtin actions +# # 2. Query Available Actions +# # Run an op +# # 3. Create a 1-off batch action using mapping +# # 4. Create an online trigger +# # Run more ops +# # 5. Query back the feedback results. +# pass + + +""" +Framing: + +1. We support a number of functions that serve as scorers from a standard lib like https://docs.ragas.io/en/stable/concepts/metrics +2. Each scorer can have a config to configure the rules of the scorer (think of this like a closure) +3. 
When executing a scorer, we will need to define a mapping for an op (inputs and outputs) to the specific fields + + +(Scorers - Hard coded, but versioned nonetheless) +Mapping (Mapping from Op to Scorer fields) +Run (single / Batch) - not saved, needs config +Online - query/filter, sample rate, scorer, config, mapping, op + + + +Spec: + +""" + +# Shouldn't actually put these in the user space +# input_schema=actions.JSONSchema( +# schema={ +# "type": "object", +# "properties": {"prompt": {"type": "string"}}, +# "required": ["prompt"], +# "additionalProperties": False, +# } +# ), +# config_schema=actions.JSONSchema( +# schema={ +# "type": "object", +# "properties": { +# "system_prompt": {"type": "string"}, +# "response_format": {"type": "object"}, +# }, +# "required": ["system_prompt"], +# "additionalProperties": False, +# } +# ), diff --git a/weave/collection_objects/action_objects.py b/weave/collection_objects/action_objects.py new file mode 100644 index 0000000000..7f41e538b1 --- /dev/null +++ b/weave/collection_objects/action_objects.py @@ -0,0 +1,13 @@ +import weave +from weave.trace_server.interface.collections.action_collection import ( + ActionOpMapping, + ActionWithConfig, +) + + +class ActionWithConfigObject(weave.Object, ActionWithConfig): + pass + + +class ActionOpMappingObject(weave.Object, ActionOpMapping): + pass diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 2f828088a7..aac7df09ef 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -30,6 +30,7 @@ import threading from collections import defaultdict from contextlib import contextmanager +from functools import partial from typing import ( Any, Dict, @@ -48,6 +49,7 @@ from clickhouse_connect.driver.client import Client as CHClient from clickhouse_connect.driver.query import QueryResult from clickhouse_connect.driver.summary import QuerySummary 
+from pydantic import BaseModel from weave.trace_server import clickhouse_trace_server_migrator as wf_migrator from weave.trace_server import environment as wf_env @@ -78,6 +80,17 @@ validate_feedback_purge_req, ) from weave.trace_server.ids import generate_id +from weave.trace_server.interface.collections.action_collection import ( + ActionOpMapping, + ActionWithConfig, +) +from weave.trace_server.interface.collections.collection import ( + make_python_object_from_dict, +) +from weave.trace_server.interface.feedback_types.action_feedback_type import ( + ACTION_FEEDBACK_TYPE_NAME, + ActionFeedback, +) from weave.trace_server.orm import ParamBuilder, Row from weave.trace_server.table_query_builder import ( ROW_ORDER_COLUMN_NAME, @@ -566,16 +579,33 @@ def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: return tsi.OpQueryRes(op_objs=objs) def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - json_val = json.dumps(req.obj.val) - digest = str_digest(json_val) - req_obj = req.obj + dict_val = req_obj.val + + from weave.trace_server.interface.collections.collection_registry import ( + collections, + ) + + if req.collection_name: + for cr in collections: + if cr.name == req.collection_name: + dict_val = make_python_object_from_dict( + cr.name, + cr.base_model_spec.__name__, + ["BaseModel"], + dict_val, + ) + break + + json_val = json.dumps(dict_val) + digest = str_digest(json_val) + print("!", req_obj.object_id, digest, json_val) ch_obj = ObjCHInsertable( project_id=req_obj.project_id, object_id=req_obj.object_id, - kind=get_kind(req.obj.val), - base_object_class=get_base_object_class(req.obj.val), - refs=extract_refs_from_values(req.obj.val), + kind=get_kind(req_obj.val), + base_object_class=get_base_object_class(req_obj.val), + refs=extract_refs_from_values(req_obj.val), val_dump=json_val, digest=digest, ) @@ -1318,8 +1348,19 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: assert_non_null_wb_user_id(req) 
validate_feedback_create_req(req) + feedback_type = req.feedback_type + res_payload = req.payload + # move to top of file + from weave.trace_server.interface.feedback_types.feedback_type_registry import ( + feedback_types, + ) + + for ft in feedback_types: + if ft.name == feedback_type: + res_payload = ft.payload_spec.model_validate(res_payload).model_dump() + break + # Augment emoji with alias. - res_payload = {} if req.feedback_type == "wandb.reaction.1": em = req.payload["emoji"] if emoji.emoji_count(em) != 1: @@ -1385,6 +1426,176 @@ def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: self.ch_client.query(prepared.sql, prepared.parameters) return tsi.FeedbackPurgeRes() + def execute_batch_action( + self, req: tsi.ExecuteBatchActionReq + ) -> tsi.ExecuteBatchActionRes: + if req.mapping is None and req.mapping_ref is None: + raise InvalidRequest("Either mapping or mapping_ref must be provided") + if req.mapping is not None and req.mapping_ref is not None: + raise InvalidRequest("Only one of mapping or mapping_ref can be provided") + + mapping: ActionOpMapping + mapping_ref: str + if req.mapping is not None: + mapping = req.mapping + from weave.trace_server.interface.collections.action_collection import ( + action_op_mapping_collection, + action_with_config_collection, + ) + + action_ref: str + if isinstance(mapping.action, dict): + mapping.action = ActionWithConfig.model_validate(mapping.action) + if isinstance(mapping.action, ActionWithConfig): + action_digest = self.obj_create( + tsi.ObjCreateReq( + collection_name=action_with_config_collection.name, + obj=tsi.ObjSchemaForInsert( + project_id=req.project_id, + object_id=mapping.action.name, + val=mapping.action.model_dump(), + ), + ) + ).digest + action_ref = ri.InternalObjectRef( + project_id=req.project_id, + name=mapping.action.name, + version=action_digest, + ).uri() + elif isinstance(mapping.action, str): + action_ref = mapping.action + else: + raise InvalidRequest("Invalid action") 
+ + mapping_val = mapping.model_dump() + mapping_val["action"] = action_ref # YUK YUK YUK YUK + + digest = self.obj_create( + tsi.ObjCreateReq( + collection_name=action_op_mapping_collection.name, + obj=tsi.ObjSchemaForInsert( + project_id=req.project_id, + object_id=mapping.name, + val=mapping_val, + ), + ) + ).digest + mapping_ref = ri.InternalObjectRef( + project_id=req.project_id, + name=mapping.name, + version=digest, + ).uri() + elif req.mapping_ref is not None: + mapping_ref = req.mapping_ref + mapping_val_res = self.refs_read_batch( + tsi.RefsReadBatchReq(refs=[req.mapping_ref]) + ) + mapping_val = mapping_val_res.vals[0] + maybe_action = mapping_val.get("action") + if isinstance(maybe_action, dict): + action_dict = maybe_action + elif isinstance(maybe_action, str): + action_dict_res = self.refs_read_batch( + tsi.RefsReadBatchReq(refs=[req.mapping_ref]) + ) + action_dict = action_dict_res.vals[0] + + mapping_val["action"] = action_dict + mapping = tsi.ActionOpMapping.model_validate(mapping_val) + else: + raise InvalidRequest("Either mapping or mapping_ref must be provided") + + if mapping.action.action.action_type != "builtin": + raise InvalidRequest( + "Only builtin actions are supported for batch execution" + ) + + if mapping.action.action.name != "openai_completion": + raise InvalidRequest( + "Only openai_completion is supported for batch execution" + ) + + if mapping.action.action.digest != "*": + raise InvalidRequest("Digest must be '*' for batch execution") + + # Step 1: Get all the calls in the batch + calls = self.calls_query_stream( + tsi.CallsQueryReq( + project_id=req.project_id, + filter=tsi.CallsFilter( + call_ids=req.call_ids, + op_names=[ + ri.InternalOpRef( + project_id=req.project_id, + name=mapping.op_name, + version=mapping.op_digest, + ).uri(), + ], + ), + ) + ) + + # Normally we would dispatch here, but just hard coding for now + # We should do some validation here + config = mapping.action.config + model = config["model"] + 
system_prompt = config["system_prompt"] + response_format = config["response_format"] + action_name = mapping.action.name + + mapping = mapping.input_mapping + + # Step 2: For Each call, execute the action: (this needs a lot of safety checks) + for call in calls: + args = {} + for input_name, call_selector in mapping.items(): + call_selector_parts = call_selector.split(".") + val = call + for part in call_selector_parts: + if isinstance(val, dict): + val = val[part] + elif isinstance(val, list): + val = val[int(part)] + elif isinstance(val, BaseModel): + val = getattr(val, part) + else: + raise InvalidRequest(f"Invalid call selector: {call_selector}") + args[input_name] = val + + from openai import OpenAI + + client = OpenAI() + # Silly hack to get around issue in tests: + create = client.chat.completions.create + if hasattr(create, "resolve_fn"): + create = partial(create.resolve_fn, self=client.chat.completions) + completion = create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": json.dumps(args)}, + ], + response_format=response_format, + ) + self.feedback_create( + tsi.FeedbackCreateReq( + project_id=req.project_id, + weave_ref=ri.InternalCallRef( + project_id=req.project_id, + id=call.id, + ).uri(), + feedback_type=ACTION_FEEDBACK_TYPE_NAME, + wb_user_id="ACTIONS_BOT", # - THIS IS NOT GOOD! 
+ payload=ActionFeedback( + name=action_name, + action_mapping_ref=mapping_ref, + results=json.loads(completion.choices[0].message.content), + ).model_dump(), + ) + ) + + return tsi.ExecuteBatchActionRes() + # Private Methods @property def ch_client(self) -> CHClient: @@ -1948,7 +2159,7 @@ def _process_parameters( def get_type(val: Any) -> str: - if val == None: + if val is None: return "none" elif isinstance(val, dict): if "_type" in val: diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 588fd56dfa..1acdb0c55c 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -345,3 +345,9 @@ def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: raise ValueError("Internal Error - Project Mismatch") cost["pricing_level_id"] = original_project_id return res + + def execute_batch_action( + self, req: tsi.ExecuteBatchActionReq + ) -> tsi.ExecuteBatchActionRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + return self._ref_apply(self._internal_trace_server.execute_batch_action, req) diff --git a/weave/trace_server/interface/collections/action_collection.py b/weave/trace_server/interface/collections/action_collection.py new file mode 100644 index 0000000000..fabb92dc38 --- /dev/null +++ b/weave/trace_server/interface/collections/action_collection.py @@ -0,0 +1,54 @@ +from typing import Literal + +from pydantic import BaseModel + +from weave.trace_server.interface.collections.collection import Collection + +# class JSONSchema(BaseModel): +# schema: dict + + +# This is only here for completeness, I think we are going to hardcode a list for now so they don't need to exist in every project +class _BuiltinAction(BaseModel): + action_type: Literal["builtin"] = "builtin" + name: str + digest: str = "*" + # input_schema: JSONSchema + # config_schema: JSONSchema + + 
+class ActionWithConfig(BaseModel): + name: str + action: _BuiltinAction + config: dict + + +# # Future +# class OpAction(Action): +# action_type: Literal["op"] +# op: Op + + +class ActionOpMapping(BaseModel): + name: str # uggg, want to get rid of this + action: ActionWithConfig + op_name: str + op_digest: str + input_mapping: dict[str, str] # Input field name -> Call selector + + +# class ActionFilterTrigger(BaseModel): +# attribute_filter: dict[str, str] # Could the CallFilter. +# sample_rate: float +# mapping: ActionOpMapping +# config: dict + +action_with_config_collection = Collection( + name="ActionWithConfig", + base_model_spec=ActionWithConfig, +) + +action_op_mapping_collection = Collection( + name="ActionOpMapping", + base_model_spec=ActionOpMapping, +) diff --git a/weave/trace_server/interface/collections/collection.py b/weave/trace_server/interface/collections/collection.py new file mode 100644 index 0000000000..18f0b7248c --- /dev/null +++ b/weave/trace_server/interface/collections/collection.py @@ -0,0 +1,19 @@ +from typing import Type + +from pydantic import BaseModel + + +class Collection(BaseModel): + name: str + base_model_spec: Type[BaseModel] + + +def make_python_object_from_dict( + type_name: str, class_name: str, bases: list[str], dict_val: dict +) -> BaseModel: + return { + "_type": type_name, + **dict_val, + # "_class_name": class_name, + # "_bases": bases, + } diff --git a/weave/trace_server/interface/collections/collection_registry.py b/weave/trace_server/interface/collections/collection_registry.py new file mode 100644 index 0000000000..e44684450b --- /dev/null +++ b/weave/trace_server/interface/collections/collection_registry.py @@ -0,0 +1,10 @@ +from weave.trace_server.interface.collections.action_collection import ( + action_op_mapping_collection, + action_with_config_collection, +) +from weave.trace_server.interface.collections.collection import Collection + +collections: list[Collection] = [ + action_with_config_collection, + 
action_op_mapping_collection, +] diff --git a/weave/trace_server/interface/feedback_types/action_feedback_type.py b/weave/trace_server/interface/feedback_types/action_feedback_type.py new file mode 100644 index 0000000000..f921b66fba --- /dev/null +++ b/weave/trace_server/interface/feedback_types/action_feedback_type.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from weave.trace_server.interface.feedback_types.feedback_type import FeedbackType + +ACTION_FEEDBACK_TYPE_NAME = "wandb.action.1" + + +class ActionFeedback(BaseModel): + name: str + action_mapping_ref: str + results: dict + + +action_feedback_type = FeedbackType( + name=ACTION_FEEDBACK_TYPE_NAME, + payload_spec=ActionFeedback, +) diff --git a/weave/trace_server/interface/feedback_types/feedback_type.py b/weave/trace_server/interface/feedback_types/feedback_type.py new file mode 100644 index 0000000000..ce32755055 --- /dev/null +++ b/weave/trace_server/interface/feedback_types/feedback_type.py @@ -0,0 +1,8 @@ +from typing import Type + +from pydantic import BaseModel + + +class FeedbackType(BaseModel): + name: str + payload_spec: Type[BaseModel] diff --git a/weave/trace_server/interface/feedback_types/feedback_type_registry.py b/weave/trace_server/interface/feedback_types/feedback_type_registry.py new file mode 100644 index 0000000000..f473789053 --- /dev/null +++ b/weave/trace_server/interface/feedback_types/feedback_type_registry.py @@ -0,0 +1,6 @@ +from weave.trace_server.interface.feedback_types.action_feedback_type import ( + action_feedback_type, +) +from weave.trace_server.interface.feedback_types.feedback_type import FeedbackType + +feedback_types: list[FeedbackType] = [action_feedback_type] diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 4df2fc19c5..bef1319839 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1081,6 +1081,11 @@ def cost_purge(self, req: tsi.CostPurgeReq) -> 
tsi.CostPurgeRes: print("COST PURGE is not implemented for local sqlite", req) return tsi.CostPurgeRes() + def execute_batch_action( + self, req: tsi.ExecuteBatchActionReq + ) -> tsi.ExecuteBatchActionRes: + pass + def _table_row_read(self, project_id: str, row_digest: str) -> tsi.TableRowSchema: conn, cursor = get_conn_cursor(self.db_path) # Now get the rows diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 442ed223ca..dedf811f78 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -1,10 +1,12 @@ import datetime +import typing from enum import Enum from typing import Any, Dict, Iterator, List, Literal, Optional, Protocol, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer from typing_extensions import TypedDict +from weave.trace_server.interface.collections.action_collection import ActionOpMapping from weave.trace_server.interface.query import Query WB_USER_ID_DESCRIPTION = ( @@ -351,6 +353,7 @@ class OpQueryRes(BaseModel): class ObjCreateReq(BaseModel): obj: ObjSchemaForInsert + collection_name: Optional[str] = None class ObjCreateRes(BaseModel): @@ -796,6 +799,17 @@ class CostPurgeRes(BaseModel): pass +class ExecuteBatchActionReq(BaseModel): + project_id: str + call_ids: list[str] + mapping: typing.Optional[ActionOpMapping] = None + mapping_ref: typing.Optional[str] = None + + +class ExecuteBatchActionRes(BaseModel): + pass + + class TraceServerInterface(Protocol): def ensure_project_exists( self, entity: str, project: str @@ -837,3 +851,35 @@ def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ... def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ... def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ... 
+ + # Action API + def execute_batch_action( + self, req: ExecuteBatchActionReq + ) -> ExecuteBatchActionRes: ... + + # Tim's Custom ideas + # def create_collection_object(self, req: CreateCollectionObjectReq) -> CreateCollectionObjectRes: ... + # def create_feedback_entry(self, req: CreateFeedbackEntryReq) -> CreateFeedbackEntryRes: ... + + # def action_create(self, req: ActionCreateReq) -> ActionCreateRes: ... + # def action_read(self, req: ActionReadReq) -> ActionReadRes: ... + # def actions_list(self) -> list[Action]: ... + + +# class CreateCollectionObjectReq(BaseModel): +# project_id: str +# object_id: str +# collection_name: str +# payload: dict + +# class CreateCollectionObjectRes(BaseModel): +# digest: str + +# class CreateFeedbackEntryReq(BaseModel): +# project_id: str = Field(examples=["entity/project"]) +# weave_ref: str = Field(examples=["weave:///entity/project/object/name:digest"]) +# creator: Optional[str] = Field(default=None, examples=["Jane Smith"]) +# feedback_type: str = Field(examples=["custom"]) +# payload: dict +# # wb_user_id is automatically populated by the server +# wb_user_id: Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 94228574d2..2b8dc7ae17 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -265,7 +265,7 @@ def call_start( req_as_obj = tsi.CallStartReq.model_validate(req) else: req_as_obj = req - if req_as_obj.start.id == None or req_as_obj.start.trace_id == None: + if req_as_obj.start.id is None or req_as_obj.start.trace_id is None: raise ValueError( "CallStartReq must have id and trace_id when batching." 
) @@ -549,6 +549,16 @@ def cost_purge( "/cost/purge", req, tsi.CostPurgeReq, tsi.CostPurgeRes ) + def execute_batch_action( + self, req: tsi.ExecuteBatchActionReq + ) -> tsi.ExecuteBatchActionRes: + return self._generic_request( + "/execute/batch_action", + req, + tsi.ExecuteBatchActionReq, + tsi.ExecuteBatchActionRes, + ) + __docspec__ = [ RemoteHTTPTraceServer,