Skip to content

Commit

Permalink
feat: v0.4.0 release
Browse files Browse the repository at this point in the history
- fix: update index of GeneratedImage.alt
- fix: now secure_1psidts will only be passed to cookies if it is not empty
- feat: add an optional parameter to Image.save() to skip images with invalid file names
- feat: Image.save() now returns saved file path on success
- refactor: code reorganization and cleanup
- test: updated unit tests
  • Loading branch information
HanaokaYuzu committed Feb 29, 2024
1 parent b4f8e92 commit b5c58a8
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 142 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ Note: by default, when asked to send images (like the previous example), Gemini

### Save images to local files

You can save images returned from Gemini to local files under `/temp` by calling `Image.save()`. Optionally, you can specify the file path and file name by passing `path` and `filename` arguments to the function. Works for both `WebImage` and `GeneratedImage`.
You can save images returned from Gemini to local files under `/temp` by calling `Image.save()`. Optionally, you can specify the file path and file name by passing `path` and `filename` arguments to the function and skip images with invalid file names by passing `skip_invalid_filename=True`. Works for both `WebImage` and `GeneratedImage`.

```python
async def main():
response = await client.generate_content("Generate some pictures of cats")
for i, image in enumerate(response.images):
await image.save(path="temp/", filename=f"cat_{i}.png")
await image.save(path="temp/", filename=f"cat_{i}.png", verbose=True)

asyncio.run(main())
```
Expand Down
1 change: 1 addition & 0 deletions src/gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .client import GeminiClient, ChatSession # noqa: F401
from .exceptions import * # noqa: F401, F403
from .types import * # noqa: F401, F403
26 changes: 10 additions & 16 deletions src/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,9 @@
from httpx import AsyncClient, ReadTimeout
from loguru import logger

from .consts import HEADERS
from .types import (
WebImage,
GeneratedImage,
Candidate,
ModelOutput,
AuthError,
APIError,
GeminiError,
TimeoutError,
)
from .types import WebImage, GeneratedImage, Candidate, ModelOutput
from .exceptions import APIError, AuthError, TimeoutError, GeminiError
from .constant import HEADERS


def running(func) -> callable:
Expand Down Expand Up @@ -71,10 +63,7 @@ def __init__(
secure_1psidts: Optional[str] = None,
proxy: Optional[dict] = None,
):
self.cookies = {
"__Secure-1PSID": secure_1psid,
"__Secure-1PSIDTS": secure_1psidts,
}
self.cookies = {"__Secure-1PSID": secure_1psid}
self.proxy = proxy
self.client: AsyncClient | None = None
self.access_token: Optional[str] = None
Expand All @@ -83,6 +72,9 @@ def __init__(
self.close_delay: float = 300
self.close_task: Task | None = None

if secure_1psidts:
self.cookies["__Secure-1PSIDTS"] = secure_1psidts

async def init(
self, timeout: float = 30, auto_close: bool = False, close_delay: float = 300
) -> None:
Expand Down Expand Up @@ -248,7 +240,9 @@ async def generate_content(
GeneratedImage(
url=image[0][3][3],
title=f"[Generated Image {image[3][6]}]",
alt=image[3][5][i],
alt=len(image[3][5]) > i
and image[3][5][i]
or image[3][5][0],
cookies=self.cookies,
)
for i, image in enumerate(candidate[12][7][0])
Expand Down
File renamed without changes.
30 changes: 30 additions & 0 deletions src/gemini/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class AuthError(Exception):
"""
Exception for authentication errors caused by invalid credentials/cookies.
"""

pass


class APIError(Exception):
"""
Exception for package-level errors which need to be fixed in the future development (e.g. validation errors).
"""

pass


class GeminiError(Exception):
"""
Exception for errors returned from Gemini server which are not handled by the package.
"""

pass


class TimeoutError(GeminiError):
"""
Exception for request timeouts.
"""

pass
3 changes: 3 additions & 0 deletions src/gemini/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .image import Image, WebImage, GeneratedImage # noqa: F401
from .candidate import Candidate # noqa: F401
from .modeloutput import ModelOutput # noqa: F401
35 changes: 35 additions & 0 deletions src/gemini/types/candidate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pydantic import BaseModel

from .image import Image, WebImage, GeneratedImage


class Candidate(BaseModel):
"""
A single reply candidate object in the model output. A full response from Gemini usually contains multiple reply candidates.
Parameters
----------
rcid: `str`
Reply candidate ID to build the metadata
text: `str`
Text output
web_images: `list[WebImage]`, optional
List of web images in reply, can be empty.
generated_images: `list[GeneratedImage]`, optional
List of generated images in reply, can be empty
"""

rcid: str
text: str
web_images: list[WebImage] = []
generated_images: list[GeneratedImage] = []

def __str__(self):
return self.text

def __repr__(self):
return f"Candidate(rcid='{self.rcid}', text='{len(self.text) <= 20 and self.text or self.text[:20] + '...'}', images={self.images})"

@property
def images(self) -> list[Image]:
return self.web_images + self.generated_images
149 changes: 28 additions & 121 deletions src/gemini/types.py → src/gemini/types/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ async def save(
path: str = "temp",
filename: str | None = None,
cookies: dict | None = None,
verbose: bool = False
) -> None:
verbose: bool = False,
skip_invalid_filename: bool = False,
) -> str | None:
"""
Save the image to disk.
Expand All @@ -46,22 +47,27 @@ async def save(
path: `str`, optional
Path to save the image, by default will save to ./temp
filename: `str`, optional
Filename to save the image, by default will use the original filename from the URL
File name to save the image, by default will use the original file name from the URL
cookies: `dict`, optional
Cookies used for requesting the content of the image
verbose : `bool`, optional
If True, print the path of the saved file, by default False
If True, print the path of the saved file or warning for invalid file name, by default False
skip_invalid_filename: `bool`, optional
If True, will only save the image if the file name and extension are valid, by default False
Returns
-------
`str | None`
Absolute path of the saved image if successful, None if filename is invalid and `skip_invalid_filename` is True
"""
filename = filename or self.url.split("/")[-1].split("?")[0]
try:
filename = (
filename
or (
re.search(r"^(.*\.\w+)", self.url.split("/")[-1])
or re.search(r"^(.*)\?", self.url.split("/")[-1])
).group()
)
filename = re.search(r"^(.*\.\w+)", filename).group()
except AttributeError:
filename = self.url.split("/")[-1]
if verbose:
logger.warning(f"Invalid filename: {filename}")
if skip_invalid_filename:
return None

async with AsyncClient(follow_redirects=True, cookies=cookies) as client:
response = await client.get(self.url)
Expand All @@ -80,6 +86,8 @@ async def save(

if verbose:
logger.info(f"Image saved as {dest.resolve()}")

return dest.resolve()
else:
raise HTTPError(
f"Error downloading image: {response.status_code} {response.reason_phrase}"
Expand Down Expand Up @@ -110,129 +118,28 @@ class GeneratedImage(Image):
@field_validator("cookies")
@classmethod
def validate_cookies(cls, v: dict) -> dict:
if "__Secure-1PSID" not in v or "__Secure-1PSIDTS" not in v:
if len(v) == 0:
raise ValueError(
"Cookies must contain '__Secure-1PSID' and '__Secure-1PSIDTS'"
"GeneratedImage is designed to be initiated with same cookies as GeminiClient."
)
return v

# @override
async def save(self, path: str = "temp/", filename: str = None) -> None:
async def save(self, **kwargs) -> None:
"""
Save the image to disk.
Parameters
----------
path: `str`
Path to save the image
filename: `str`, optional
Filename to save the image, generated images are always in .png format, but file extension will not be included in the URL.
And since the URL ends with a long hash, by default will use timestamp + end of the hash as the filename
**kwargs: `dict`, optional
Other arguments to pass to `Image.save`
"""
await super().save(
path,
filename
filename=kwargs.pop("filename", None)
or f"{datetime.now().strftime('%Y%m%d%H%M%S')}_{self.url[-10:]}.png",
self.cookies,
cookies=self.cookies,
**kwargs,
)


class Candidate(BaseModel):
"""
A single reply candidate object in the model output. A full response from Gemini usually contains multiple reply candidates.
Parameters
----------
rcid: `str`
Reply candidate ID to build the metadata
text: `str`
Text output
web_images: `list[WebImage]`, optional
List of web images in reply, can be empty.
generated_images: `list[GeneratedImage]`, optional
List of generated images in reply, can be empty
"""

rcid: str
text: str
web_images: list[WebImage] = []
generated_images: list[GeneratedImage] = []

def __str__(self):
return self.text

def __repr__(self):
return f"Candidate(rcid='{self.rcid}', text='{len(self.text) <= 20 and self.text or self.text[:20] + '...'}', images={self.images})"

@property
def images(self) -> list[Image]:
return self.web_images + self.generated_images


class ModelOutput(BaseModel):
"""
Classified output from gemini.google.com
Parameters
----------
metadata: `list[str]`
List of chat metadata `[cid, rid, rcid]`, can be shorter than 3 elements, like `[cid, rid]` or `[cid]` only
candidates: `list[Candidate]`
List of all candidates returned from gemini
chosen: `int`, optional
Index of the chosen candidate, by default will choose the first one
"""

metadata: list[str]
candidates: list[Candidate]
chosen: int = 0

def __str__(self):
return self.text

def __repr__(self):
return f"ModelOutput(metadata={self.metadata}, chosen={self.chosen}, candidates={self.candidates})"

@property
def text(self) -> str:
return self.candidates[self.chosen].text

@property
def images(self) -> list[Image]:
return self.candidates[self.chosen].images

@property
def rcid(self) -> str:
return self.candidates[self.chosen].rcid


class AuthError(Exception):
"""
Exception for authentication errors caused by invalid credentials/cookies.
"""

pass


class APIError(Exception):
"""
Exception for package-level errors which need to be fixed in the future development (e.g. validation errors).
"""

pass


class GeminiError(Exception):
"""
Exception for errors returned from Gemini server which are not handled by the package.
"""

pass


class TimeoutError(GeminiError):
"""
Exception for request timeouts.
"""

pass
41 changes: 41 additions & 0 deletions src/gemini/types/modeloutput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pydantic import BaseModel

from .image import Image
from .candidate import Candidate


class ModelOutput(BaseModel):
"""
Classified output from gemini.google.com
Parameters
----------
metadata: `list[str]`
List of chat metadata `[cid, rid, rcid]`, can be shorter than 3 elements, like `[cid, rid]` or `[cid]` only
candidates: `list[Candidate]`
List of all candidates returned from gemini
chosen: `int`, optional
Index of the chosen candidate, by default will choose the first one
"""

metadata: list[str]
candidates: list[Candidate]
chosen: int = 0

def __str__(self):
return self.text

def __repr__(self):
return f"ModelOutput(metadata={self.metadata}, chosen={self.chosen}, candidates={self.candidates})"

@property
def text(self) -> str:
return self.candidates[self.chosen].text

@property
def images(self) -> list[Image]:
return self.candidates[self.chosen].images

@property
def rcid(self) -> str:
return self.candidates[self.chosen].rcid
6 changes: 5 additions & 1 deletion tests/test_client_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ async def test_reply_candidates(self):
response = await chat.send_message(
"What's the best Japanese dish? Recommend one only."
)
self.assertTrue(len(response.candidates) > 1)

if len(response.candidates) == 1:
logger.debug(response.candidates[0])
self.skipTest("Only one candidate was returned. Test skipped")

for candidate in response.candidates:
logger.debug(candidate)

Expand Down
Loading

0 comments on commit b5c58a8

Please sign in to comment.