diff --git a/.github/workflows/common_checks.yaml b/.github/workflows/common_checks.yaml index 1c03be86..4c8c0a6b 100644 --- a/.github/workflows/common_checks.yaml +++ b/.github/workflows/common_checks.yaml @@ -154,3 +154,30 @@ jobs: tar -xzf gitleaks_8.10.1_linux_x64.tar.gz && \ sudo install gitleaks /usr/bin && \ gitleaks detect --report-format json --report-path leak_report -v + + tools_checks: + continue-on-error: False + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + python-version: [ "3.10" ] + + timeout-minutes: 30 + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt-get update --fix-missing + sudo apt-get autoremove + sudo apt-get autoclean + pip install tomte[tox,cli]==0.2.15 + pip install --user --upgrade setuptools + - name: Tool unit tests + run: tox -e check-tools \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..8afb6123 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2024 Valory AG +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ +"""This module contains tests.""" diff --git a/tests/constants.py b/tests/constants.py new file mode 100644 index 00000000..b2aef2ec --- /dev/null +++ b/tests/constants.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2024 Valory AG +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ +"""This module contains constants.""" + +import os + +OPENAI_SECRET_KEY = os.getenv("OPENAI_SECRET_KEY") +STABILITY_API_KEY = os.getenv("STABILITY_API_KEY") +GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") +GOOGLE_ENGINE_ID = os.getenv("GOOGLE_ENGINE_ID") +CLAUDE_API_KEY = os.getenv("CLAUDE_API_KEY") +REPLICATE_API_KEY = os.getenv("REPLICATE_API_KEY") +NEWS_API_KEY = os.getenv("NEWS_API_KEY") diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..04e1f379 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2024 Valory AG +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ +"""This module contains tool tests.""" + +from typing import List, Any + +from packages.valory.customs.prediction_request_claude import prediction_request_claude +from packages.valory.skills.task_execution.utils.benchmarks import TokenCounterCallback +from tests.constants import ( + OPENAI_SECRET_KEY, + STABILITY_API_KEY, + GOOGLE_API_KEY, + GOOGLE_ENGINE_ID, + CLAUDE_API_KEY, + REPLICATE_API_KEY, + NEWS_API_KEY, +) + + +class BaseToolTest: + """Base tool test class.""" + keys = { + "openai": OPENAI_SECRET_KEY, + "stabilityai": STABILITY_API_KEY, + "google_api_key": GOOGLE_API_KEY, + "google_engine_id": GOOGLE_ENGINE_ID, + "anthropic": CLAUDE_API_KEY, + "replicate": REPLICATE_API_KEY, + "newsapi": NEWS_API_KEY, + } + models: List = [None] + tools: List[str] + prompts: List[str] + tool_module: Any = None + tool_callable: str = "run" + + def _validate_response(self, response: Any) -> None: + """Validate response.""" + assert type(response) == tuple, "Response of the tool must be a tuple." + assert len(response) == 4, "Response must have 4 elements." + assert type(response[0]) == str, "Response[0] must be a string." + assert type(response[1]) == str, "Response[1] must be a string." + assert type(response[2]) == dict or response[2] is None, "Response[2] must be a dictionary or None." + assert type(response[3]) == TokenCounterCallback or response[3] is None, "Response[3] must be a TokenCounterCallback or None." + + def test_run(self) -> None: + """Test run method.""" + assert self.tools, "Tools must be provided." + assert self.prompts, "Prompts must be provided." + assert self.tool_module, "Callable function must be provided." + + for model in self.models: + for tool in self.tools: + for prompt in self.prompts: + kwargs = dict( + prompt=prompt, + tool=tool, + api_keys=self.keys, + counter_callback=TokenCounterCallback(), + model=model, + ) + func = getattr(self.tool_module, self.tool_callable) + response = func(**kwargs) + self._validate_response(response) + + +class TestClaudePredictionOnline(BaseToolTest): + """Test Claude Prediction Online.""" + + tools = prediction_request_claude.ALLOWED_TOOLS + models = prediction_request_claude.ALLOWED_MODELS + prompts = [ + "Please take over the role of a Data Scientist to evaluate the given question. With the given question \"Will Apple release iPhone 17 by March 2025?\" and the `yes` option represented by `Yes` and the `no` option represented by `No`, what are the respective probabilities of `p_yes` and `p_no` occurring?" + ] + tool_module = prediction_request_claude diff --git a/tox.ini b/tox.ini index 22119e5d..9ba66b1f 100644 --- a/tox.ini +++ b/tox.ini @@ -241,6 +241,10 @@ commands = aea packages sync {toxinidir}/scripts/check_doc_ipfs_hashes.py +[testenv:check-tools] +deps = {[deps-packages]deps} +commands = pytest tests + [testenv:fix-doc-hashes] skipsdist = True skip_install = True