Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tests #1486

Merged
merged 2 commits into from
Jan 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions .github/workflows/copilot.yml
Original file line number Diff line number Diff line change
@@ -6,14 +6,10 @@ on:
- opened
- synchronize

env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

permissions: write-all

jobs:
review:
runs-on: ubuntu-latest
permissions: write-all
steps:
- name: Checkout Repo
uses: actions/checkout@v3
@@ -28,4 +24,6 @@ jobs:
- name: Install PyGithub
run: pip install PyGithub
- name: AI Code Review
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: python -m etc.tool.copilot
11 changes: 9 additions & 2 deletions .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
name: Unittest

on: [push]
on:
pull_request:
types:
- opened
- synchronize
push:
branches:
- 'main'

jobs:
build:
@@ -16,4 +23,4 @@ jobs:
- name: Install requirements
run: pip install -r requirements.txt
- name: Run tests
run: python -m etc.unittest.main
run: python -m etc.unittest
6 changes: 6 additions & 0 deletions etc/unittest/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import unittest
from .asyncio import *
from .backend import *
from .main import *

unittest.main()
57 changes: 57 additions & 0 deletions etc/unittest/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from .include import DEFAULT_MESSAGES
import asyncio
import nest_asyncio
import unittest
import g4f
from g4f import ChatCompletion
from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock

class TestChatCompletion(unittest.TestCase):

async def run_exception(self):
return ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)

def test_exception(self):
self.assertRaises(g4f.errors.NestAsyncioError, asyncio.run, self.run_exception())

def test_create(self):
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
self.assertEqual("Mock",result)

def test_create_generator(self):
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
self.assertEqual("Mock",result)

class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):

async def test_base(self):
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
self.assertEqual("Mock",result)

async def test_async(self):
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
self.assertEqual("Mock",result)

async def test_create_generator(self):
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
self.assertEqual("Mock",result)

class TestChatCompletionNestAsync(unittest.IsolatedAsyncioTestCase):

def setUp(self) -> None:
nest_asyncio.apply()

async def test_create(self):
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
self.assertEqual("Mock",result)

async def test_nested(self):
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
self.assertEqual("Mock",result)

async def test_nested_generator(self):
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
self.assertEqual("Mock",result)

if __name__ == '__main__':
unittest.main()
38 changes: 38 additions & 0 deletions etc/unittest/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from . import include
import unittest
from unittest.mock import MagicMock
from .mocks import ProviderMock
import g4f
from g4f.gui.server.backend import Backend_Api, get_error_message

class TestBackendApi(unittest.TestCase):

def setUp(self):
self.app = MagicMock()
self.api = Backend_Api(self.app)

def test_version(self):
response = self.api.get_version()
self.assertIn("version", response)
self.assertIn("latest_version", response)

def test_get_models(self):
response = self.api.get_models()
self.assertIsInstance(response, list)
self.assertTrue(len(response) > 0)

def test_get_providers(self):
response = self.api.get_providers()
self.assertIsInstance(response, list)
self.assertTrue(len(response) > 0)

class TestUtilityFunctions(unittest.TestCase):

def test_get_error_message(self):
g4f.debug.last_provider = ProviderMock
exception = Exception("Message")
result = get_error_message(exception)
self.assertEqual("ProviderMock: Exception: Message", result)

if __name__ == '__main__':
unittest.main()
11 changes: 11 additions & 0 deletions etc/unittest/include.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys
import pathlib

sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))

import g4f

g4f.debug.logging = False
g4f.debug.version_check = False

DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
78 changes: 20 additions & 58 deletions etc/unittest/main.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,37 @@
import sys
import pathlib
from .include import DEFAULT_MESSAGES
import unittest
from unittest.mock import MagicMock

sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))

import asyncio
import g4f
from g4f import ChatCompletion, get_last_provider
from g4f.gui.server.backend import Backend_Api, get_error_message
from g4f.base_provider import BaseProvider

g4f.debug.logging = False
g4f.debug.version_check = False

class MockProvider(BaseProvider):
working = True

def create_completion(
model, messages, stream, **kwargs
):
yield "Mock"

async def create_async(
model, messages, **kwargs
):
return "Mock"

class TestBackendApi(unittest.TestCase):

def setUp(self):
self.app = MagicMock()
self.api = Backend_Api(self.app)

def test_version(self):
response = self.api.get_version()
self.assertIn("version", response)
self.assertIn("latest_version", response)
from g4f.Provider import RetryProvider
from .mocks import ProviderMock

class TestChatCompletion(unittest.TestCase):

def test_create_default(self):
messages = [{'role': 'user', 'content': 'Hello'}]
result = ChatCompletion.create(g4f.models.default, messages)
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES)
if "Good" not in result and "Hi" not in result:
self.assertIn("Hello", result)

def test_get_last_provider(self):
messages = [{'role': 'user', 'content': 'Hello'}]
ChatCompletion.create(g4f.models.default, messages, MockProvider)
self.assertEqual(get_last_provider(), MockProvider)


def test_bing_provider(self):
messages = [{'role': 'user', 'content': 'Hello'}]
provider = g4f.Provider.Bing
result = ChatCompletion.create(g4f.models.default, messages, provider)
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, provider)
self.assertIn("Bing", result)

class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):

async def test_async(self):
messages = [{'role': 'user', 'content': 'Hello'}]
result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
self.assertEqual("Mock", result)
class TestGetLastProvider(unittest.TestCase):

class TestUtilityFunctions(unittest.TestCase):

def test_get_error_message(self):
g4f.debug.last_provider = g4f.Provider.Bing
exception = Exception("Message")
result = get_error_message(exception)
self.assertEqual("Bing: Exception: Message", result)
def test_get_last_provider(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
self.assertEqual(get_last_provider(), ProviderMock)

def test_get_last_provider_retry(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
self.assertEqual(get_last_provider(), ProviderMock)

def test_get_last_provider_async(self):
coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
asyncio.run(coroutine)
self.assertEqual(get_last_provider(), ProviderMock)

if __name__ == '__main__':
unittest.main()
25 changes: 25 additions & 0 deletions etc/unittest/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from g4f.Provider.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider

class ProviderMock(AbstractProvider):
working = True

def create_completion(
model, messages, stream, **kwargs
):
yield "Mock"

class AsyncProviderMock(AsyncProvider):
working = True

async def create_async(
model, messages, **kwargs
):
return "Mock"

class AsyncGeneratorProviderMock(AsyncGeneratorProvider):
working = True

async def create_async_generator(
model, messages, stream, **kwargs
):
yield "Mock"
13 changes: 5 additions & 8 deletions g4f/Provider/Bing.py
Original file line number Diff line number Diff line change
@@ -64,12 +64,7 @@ def create_async_generator(
prompt = messages[-1]["content"]
context = create_context(messages[:-1])

if not cookies:
cookies = Defaults.cookies
else:
for key, value in Defaults.cookies.items():
if key not in cookies:
cookies[key] = value
cookies = {**Defaults.cookies, **cookies} if cookies else Defaults.cookies

gpt4_turbo = True if model.startswith("gpt-4-turbo") else False

@@ -207,10 +202,12 @@ def create_message(
request_id = str(uuid.uuid4())
struct = {
'arguments': [{
'source': 'cib', 'optionsSets': options_sets,
'source': 'cib',
'optionsSets': options_sets,
'allowedMessageTypes': Defaults.allowedMessageTypes,
'sliceIds': Defaults.sliceIds,
'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
'traceId': os.urandom(16).hex(),
'isStartOfSession': True,
'requestId': request_id,
'message': {
**Defaults.location,
41 changes: 23 additions & 18 deletions g4f/Provider/base_provider.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,8 @@
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages
from .helper import get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages, Union
from ..base_provider import BaseProvider
from ..errors import NestAsyncioError

@@ -20,6 +20,17 @@
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

def get_running_loop() -> Union[AbstractEventLoop, None]:
try:
loop = asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
raise NestAsyncioError(
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
)
return loop
except RuntimeError:
pass

class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
@@ -56,7 +67,7 @@ def create_func() -> str:

return await asyncio.wait_for(
loop.run_in_executor(executor, create_func),
timeout=kwargs.get("timeout", 0)
timeout=kwargs.get("timeout")
)

@classmethod
@@ -118,14 +129,7 @@ def create_completion(
Returns:
CreateResult: The result of the completion creation.
"""
try:
loop = asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
raise NestAsyncioError(
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
)
except RuntimeError:
pass
get_running_loop()
yield asyncio.run(cls.create_async(model, messages, **kwargs))

@staticmethod
@@ -180,15 +184,12 @@ def create_completion(
Returns:
CreateResult: The result of the streaming completion creation.
"""
try:
loop = asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
raise NestAsyncioError(
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
)
except RuntimeError:
loop = get_running_loop()
new_loop = False
if not loop:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
new_loop = True

generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
@@ -199,6 +200,10 @@ def create_completion(
except StopAsyncIteration:
break

if new_loop:
loop.close()
asyncio.set_event_loop(None)

@classmethod
async def create_async(
cls,
3 changes: 1 addition & 2 deletions g4f/gui/server/backend.py
Original file line number Diff line number Diff line change
@@ -2,15 +2,14 @@
import json
from flask import request, Flask
from typing import Generator
from g4f import debug, version, models
from g4f import version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
from g4f.errors import VersionNotFoundError
from g4f.Provider import __providers__
from g4f.Provider.bing.create_images import patch_provider
from .internet import get_search_message

debug.logging = True

class Backend_Api:
"""
14 changes: 7 additions & 7 deletions g4f/image.py
Original file line number Diff line number Diff line change
@@ -112,7 +112,7 @@ def get_orientation(image: Image.Image) -> int:
"""
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
if orientation is not None:
return orientation

@@ -156,23 +156,23 @@ def to_base64(image: Image.Image, compression_rate: float) -> str:
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()

def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str:
"""
Formats the given images as a markdown string.
Args:
images: The images to format.
prompt (str): The prompt for the images.
alt (str): The alt for the images.
preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
Returns:
str: The formatted markdown string.
"""
if isinstance(images, list):
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
images = "\n".join(images)
if isinstance(images, str):
images = f"[![{alt}]({preview.replace('{image}', images)})]({images})"
else:
images = f"[![{prompt}]({images})]({images})"
images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
images = "\n".join(images)
start_flag = "<!-- generated images start -->\n"
end_flag = "<!-- generated images end -->\n"
return f"\n{start_flag}{images}\n{end_flag}\n"
7 changes: 7 additions & 0 deletions g4f/typing.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,14 @@
'AsyncGenerator',
'Generator',
'Tuple',
'Union',
'List',
'Dict',
'Type',
'TypedDict',
'SHA256',
'CreateResult',
'AsyncResult',
'Messages',
'ImageType'
]