-
-
Notifications
You must be signed in to change notification settings - Fork 13.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1820 from hlohaus/bugfix
Add ReplicateImage Provider, Fix BingCreateImages Provider
- Loading branch information
Showing
10 changed files
with
260 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from __future__ import annotations | ||
|
||
import random | ||
import asyncio | ||
|
||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin | ||
from ..typing import AsyncResult, Messages | ||
from ..requests import StreamSession, raise_for_status | ||
from ..image import ImageResponse | ||
from ..errors import ResponseError | ||
|
||
class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin): | ||
url = "https://replicate.com" | ||
working = True | ||
default_model = 'stability-ai/sdxl' | ||
default_versions = [ | ||
"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", | ||
"2b017d9b67edd2ee1401238df49d75da53c523f36e363881e057f5dc3ed3c5b2" | ||
] | ||
|
||
@classmethod | ||
async def create_async_generator( | ||
cls, | ||
model: str, | ||
messages: Messages, | ||
**kwargs | ||
) -> AsyncResult: | ||
yield await cls.create_async(messages[-1]["content"], model, **kwargs) | ||
|
||
@classmethod | ||
async def create_async( | ||
cls, | ||
prompt: str, | ||
model: str, | ||
api_key: str = None, | ||
proxy: str = None, | ||
timeout: int = 180, | ||
version: str = None, | ||
extra_data: dict = {}, | ||
**kwargs | ||
) -> ImageResponse: | ||
headers = { | ||
'Accept-Encoding': 'gzip, deflate, br', | ||
'Accept-Language': 'en-US', | ||
'Connection': 'keep-alive', | ||
'Origin': cls.url, | ||
'Referer': f'{cls.url}/', | ||
'Sec-Fetch-Dest': 'empty', | ||
'Sec-Fetch-Mode': 'cors', | ||
'Sec-Fetch-Site': 'same-site', | ||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36', | ||
'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"', | ||
'sec-ch-ua-mobile': '?0', | ||
'sec-ch-ua-platform': '"macOS"', | ||
} | ||
if version is None: | ||
version = random.choice(cls.default_versions) | ||
if api_key is not None: | ||
headers["Authorization"] = f"Bearer {api_key}" | ||
async with StreamSession( | ||
proxies={"all": proxy}, | ||
headers=headers, | ||
timeout=timeout | ||
) as session: | ||
data = { | ||
"input": { | ||
"prompt": prompt, | ||
**extra_data | ||
}, | ||
"version": version | ||
} | ||
if api_key is None: | ||
data["model"] = cls.get_model(model) | ||
url = "https://homepage.replicate.com/api/prediction" | ||
else: | ||
url = "https://api.replicate.com/v1/predictions" | ||
async with session.post(url, json=data) as response: | ||
await raise_for_status(response) | ||
result = await response.json() | ||
if "id" not in result: | ||
raise ResponseError(f"Invalid response: {result}") | ||
while True: | ||
if api_key is None: | ||
url = f"https://homepage.replicate.com/api/poll?id={result['id']}" | ||
else: | ||
url = f"https://api.replicate.com/v1/predictions/{result['id']}" | ||
async with session.get(url) as response: | ||
await raise_for_status(response) | ||
result = await response.json() | ||
if "status" not in result: | ||
raise ResponseError(f"Invalid response: {result}") | ||
if result["status"] == "succeeded": | ||
images = result['output'] | ||
images = images[0] if len(images) == 1 else images | ||
return ImageResponse(images, prompt) | ||
await asyncio.sleep(0.5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
|
||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin | ||
from ..helper import format_prompt, filter_none | ||
from ...typing import AsyncResult, Messages | ||
from ...requests import StreamSession, raise_for_status | ||
from ...image import ImageResponse | ||
from ...errors import ResponseError, MissingAuthError | ||
|
||
class Replicate(AsyncGeneratorProvider, ProviderModelMixin): | ||
url = "https://replicate.com" | ||
working = True | ||
default_model = "mistralai/mixtral-8x7b-instruct-v0.1" | ||
api_base = "https://api.replicate.com/v1/models/" | ||
|
||
@classmethod | ||
async def create_async_generator( | ||
cls, | ||
model: str, | ||
messages: Messages, | ||
api_key: str = None, | ||
proxy: str = None, | ||
timeout: int = 180, | ||
system_prompt: str = None, | ||
max_new_tokens: int = None, | ||
temperature: float = None, | ||
top_p: float = None, | ||
top_k: float = None, | ||
stop: list = None, | ||
extra_data: dict = {}, | ||
headers: dict = {}, | ||
**kwargs | ||
) -> AsyncResult: | ||
model = cls.get_model(model) | ||
if api_key is None: | ||
raise MissingAuthError("api_key is missing") | ||
headers["Authorization"] = f"Bearer {api_key}" | ||
async with StreamSession( | ||
proxies={"all": proxy}, | ||
headers=headers, | ||
timeout=timeout | ||
) as session: | ||
data = { | ||
"stream": True, | ||
"input": { | ||
"prompt": format_prompt(messages), | ||
**filter_none( | ||
system_prompt=system_prompt, | ||
max_new_tokens=max_new_tokens, | ||
temperature=temperature, | ||
top_p=top_p, | ||
top_k=top_k, | ||
stop_sequences=",".join(stop) if stop else None | ||
), | ||
**extra_data | ||
}, | ||
} | ||
url = f"{cls.api_base.rstrip('/')}/{model}/predictions" | ||
async with session.post(url, json=data) as response: | ||
await raise_for_status(response) | ||
result = await response.json() | ||
if "id" not in result: | ||
raise ResponseError(f"Invalid response: {result}") | ||
async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response: | ||
await raise_for_status(response) | ||
event = None | ||
async for line in response.iter_lines(): | ||
if line.startswith(b"event: "): | ||
event = line[7:] | ||
elif event == b"output": | ||
if line.startswith(b"data: "): | ||
yield line[6:].decode() | ||
elif not line.startswith(b"id: "): | ||
continue#yield "+"+line.decode() | ||
elif event == b"done": | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.