Skip to content

Commit

Permalink
Update abstra-lib
Browse files Browse the repository at this point in the history
  • Loading branch information
abstra-bot committed Nov 19, 2024
1 parent f64ce2d commit 649b9af
Show file tree
Hide file tree
Showing 229 changed files with 654 additions and 583 deletions.
147 changes: 93 additions & 54 deletions abstra_internals/controllers/sdk_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from typing import Dict, List, Optional, Union

import pypdfium2 as pdfium
from PIL.Image import Image

import abstra_internals.utils.b64 as b64
from abstra_internals.repositories.ai import AiApiHttpClient
from abstra_internals.utils.b64 import is_base_64, to_base64
from abstra_internals.utils.string import to_snake_case
from abstra_internals.widgets.response_types import FileResponse

Prompt = Union[str, io.IOBase, pathlib.Path]
Prompt = Union[str, io.IOBase, pathlib.Path, FileResponse]
Format = Dict[str, object]


Expand Down Expand Up @@ -43,14 +45,97 @@ def _extract_pdf_images(self, file: Prompt) -> List[io.BytesIO]:
images.append(image_io)
return images

def _prompt_is_valid_file(self, prompt: Prompt) -> bool:
if isinstance(prompt, (io.IOBase, pathlib.Path)) or is_base_64(prompt):
return True
def _make_image_url_message(self, url: str):
return {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": url,
},
},
],
}

def _make_text_message(self, text: str):
return {
"role": "user",
"content": text,
}

def _try_extract_images(self, input: io.IOBase) -> Optional[List[io.BytesIO]]:
try:
return pathlib.Path(prompt).exists()
images = self._extract_pdf_images(input)
return images
except Exception:
return False
return None

def _make_messages(self, prompt: Prompt) -> List:
if isinstance(prompt, pathlib.Path):
with prompt.open("rb") as f:
if images := self._try_extract_images(f):
return [
self._make_image_url_message(b64.encode_base_64(image))
for image in images
]

encoded_str = b64.encode_base_64(f)
return [self._make_image_url_message(encoded_str)]

if isinstance(prompt, FileResponse):
file = prompt.file
if images := self._try_extract_images(file):
return [
self._make_image_url_message(b64.encode_base_64(image))
for image in images
]

encoded_str = b64.encode_base_64(file)
return [self._make_image_url_message(encoded_str)]

if isinstance(prompt, io.IOBase):
prompt.seek(0)
if images := self._try_extract_images(prompt):
return [
self._make_image_url_message(b64.encode_base_64(image))
for image in images
]
encoded_str = b64.encode_base_64(prompt)
return [self._make_image_url_message(encoded_str)]

if isinstance(prompt, str) and (
b64.is_base_64(prompt) or prompt.startswith("http")
):
if prompt.endswith(".pdf"):
raise ValueError("PDF URLs are not supported")

return [self._make_image_url_message(prompt)]

try:
if isinstance(prompt, str) and pathlib.Path(prompt).exists():
with open(prompt, "rb") as f:
if images := self._try_extract_images(f):
return [
self._make_image_url_message(b64.encode_base_64(image))
for image in images
]
encoded_str = b64.encode_base_64(f)
return [self._make_image_url_message(encoded_str)]
except OSError: # Path contructor can raise OSError on long strings
pass

if isinstance(prompt, Image):
image_io = io.BytesIO()
prompt.save(image_io, format="PNG")
image_io.seek(0)
encoded_str = b64.encode_base_64(image_io)
return [self._make_image_url_message(encoded_str)]

if isinstance(prompt, str):
return [self._make_text_message(prompt)]

raise ValueError(f"Invalid prompt: {prompt}")

def prompt(
self,
Expand All @@ -70,53 +155,7 @@ def prompt(
)

for prompt in prompts:
if self._prompt_is_valid_file(prompt):
try:
images = self._extract_pdf_images(prompt)
except pdfium.PdfiumError:
# If the file is not a PDF, read it as a single image
images = [prompt]
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {prompt}")
except Exception as e:
raise Exception(f"Error reading file: {e}")

for image in images:
base_64_image = to_base64(image)
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": base_64_image,
},
},
],
}
)
elif isinstance(prompt, str) and is_base_64(prompt):
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": prompt,
},
},
],
}
)
elif isinstance(prompt, str):
messages.append(
{
"role": "user",
"content": prompt,
}
)
messages.extend(self._make_messages(prompt))

tools = []
if format:
Expand Down
146 changes: 146 additions & 0 deletions abstra_internals/controllers/sdk_ai_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import MagicMock, patch

from PIL import Image

from abstra_internals.controllers.sdk_ai import AiSDKController
from abstra_internals.widgets.response_types import FileResponse


class TestAiSDKController(unittest.TestCase):
def setUp(self):
# Mock the AiApiHttpClient to isolate tests from external API calls
self.mock_ai_client = MagicMock()
self.controller = AiSDKController(self.mock_ai_client)

def test_extract_pdf_images(self):
# Mock pdfium.PdfDocument to return a mock page
mock_page = MagicMock()
mock_page.render.return_value.to_pil.return_value = Image.new("RGB", (100, 100))
mock_pdf = MagicMock()
mock_pdf.__iter__.return_value = [mock_page]

with patch("pypdfium2.PdfDocument", return_value=mock_pdf):
images = self.controller._extract_pdf_images(mock_pdf)

self.assertEqual(len(images), 1) # One image is expected
self.assertTrue(
isinstance(images[0], BytesIO)
) # Image should be in BytesIO format

def test_make_image_url_message(self):
url = "http://example.com/image.png"
expected_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": url,
},
},
],
}

result = self.controller._make_image_url_message(url)
self.assertEqual(result, expected_message)

def test_make_text_message(self):
text = "Some text"
expected_message = {
"role": "user",
"content": text,
}

result = self.controller._make_text_message(text)
self.assertEqual(result, expected_message)

def test_make_messages_from_str(self):
text = "This is a prompt"
result = self.controller._make_messages(text)
self.assertEqual(len(result), 1)
self.assertEqual(result[0]["role"], "user")
self.assertEqual(result[0]["content"], text)

def test_make_messages_from_image(self):
# Create a dummy image to simulate the prompt
img = Image.new("RGB", (100, 100))
image_io = BytesIO()
img.save(image_io, format="PNG")
image_io.seek(0)

result = self.controller._make_messages(image_io)
self.assertEqual(len(result), 1)
self.assertIn("type", result[0]["content"][0])
self.assertEqual(result[0]["content"][0]["type"], "image_url")

def test_make_messages_from_path(self):
file_path = Path("test.pdf")
file_path.write_text("dummy content") # Creating a dummy file

with patch("builtins.open", MagicMock(return_value=BytesIO(b"fake content"))):
result = self.controller._make_messages(file_path)

self.assertEqual(len(result), 1)
self.assertIn("type", result[0]["content"][0])

def test_prompt_with_image_file(self):
# Mock AI client response
self.mock_ai_client.prompt.return_value = {"content": "Mocked AI response"}

# Create a dummy image
img = Image.new("RGB", (100, 100))
image_io = BytesIO()
img.save(image_io, format="PNG")
image_io.seek(0)

result = self.controller.prompt(
prompts=[image_io],
instructions=["This is a test instruction."],
format=None,
temperature=0.7,
)

self.assertEqual(result, "Mocked AI response")

def test_prompt_with_file_response(self):
# Mock AI client response
self.mock_ai_client.prompt.return_value = {"content": "Mocked AI response"}

# save dummy file
file_path = Path("test.pdf")
file_path.write_text("dummy content")

class MockFileResponse(FileResponse):
def __init__(self, url):
self._url = url

@property
def name(self):
return self.path.name

@property
def content(self):
return self.path.read_bytes()

@property
def file(self):
return self.path.open("rb")

@property
def path(self):
return file_path

# Mock FileResponse
file_response = MockFileResponse(url=file_path.as_posix())

result = self.controller.prompt(
prompts=[file_response],
instructions=["Test instruction."],
format=None,
temperature=0.7,
)

self.assertEqual(result, "Mocked AI response")
8 changes: 1 addition & 7 deletions abstra_internals/interface/sdk/ai.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
import io
import pathlib
from typing import Dict, List, Optional, TypeVar, Union

from abstra_internals.controllers.execution_store import ExecutionStore

Prompt = Union[str, io.IOBase, pathlib.Path]
InputFile = Union[str, pathlib.Path, io.IOBase]
Format = Dict[str, object]

from abstra_internals.controllers.sdk_ai import Format, Prompt

T = TypeVar("T")

Expand Down
Loading

0 comments on commit 649b9af

Please sign in to comment.