Skip to content

Commit

Permalink
Rename upload_file to base64_encode_file
Browse files Browse the repository at this point in the history
Remove unused output_file_prefix parameter

Add test coverage for base64_encode_file

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mattt committed Aug 22, 2024
1 parent 3cc1974 commit e19cb78
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 30 deletions.
10 changes: 5 additions & 5 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import Unpack, deprecated

from replicate.account import Account
from replicate.file import upload_file
from replicate.file import base64_encode_file
from replicate.json import encode_json
from replicate.pagination import Page
from replicate.prediction import (
Expand Down Expand Up @@ -424,7 +424,7 @@ def create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: self._client.files.create(file).urls["get"],
)
Expand All @@ -451,7 +451,7 @@ async def async_create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: asyncio.get_event_loop()
.run_until_complete(self._client.files.async_create(file))
Expand Down Expand Up @@ -489,7 +489,7 @@ def create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: self._client.files.create(file).urls["get"],
)
Expand Down Expand Up @@ -519,7 +519,7 @@ async def async_create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: asyncio.get_event_loop()
.run_until_complete(self._client.files.async_create(file))
Expand Down
21 changes: 5 additions & 16 deletions replicate/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union

import httpx
from typing_extensions import NotRequired, Unpack

from replicate.resource import Namespace, Resource
Expand Down Expand Up @@ -171,33 +170,23 @@ def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-ou
return File(**json)


def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
def base64_encode_file(file: io.IOBase) -> str:
"""
Upload a file to the server.
Base64 encode a file.
Args:
file: A file handle to upload.
output_file_prefix: A string to prepend to the output file name.
Returns:
str: A URL to the uploaded file.
str: A base64-encoded data URI.
"""
# Lifted straight from cog.files

file.seek(0)

if output_file_prefix is not None:
name = getattr(file, "name", "output")
url = output_file_prefix + os.path.basename(name)
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
resp.raise_for_status()

return url

body = file.read()

# Ensure the file handle is in bytes
body = body.encode("utf-8") if isinstance(body, str) else body
encoded_body = base64.b64encode(body).decode("utf-8")
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name

mime_type = (
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
)
Expand Down
6 changes: 3 additions & 3 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated

from replicate.exceptions import ReplicateException
from replicate.file import upload_file
from replicate.file import base64_encode_file
from replicate.identifier import ModelVersionIdentifier
from replicate.json import encode_json
from replicate.pagination import Page
Expand Down Expand Up @@ -399,7 +399,7 @@ def create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: self._client.files.create(file).urls["get"],
)
Expand Down Expand Up @@ -429,7 +429,7 @@ async def async_create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: asyncio.get_event_loop()
.run_until_complete(self._client.files.async_create(file))
Expand Down
6 changes: 3 additions & 3 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError, ReplicateError
from replicate.file import upload_file
from replicate.file import base64_encode_file
from replicate.json import encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
Expand Down Expand Up @@ -460,7 +460,7 @@ def create( # type: ignore
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: self._client.files.create(file).urls["get"],
)
Expand Down Expand Up @@ -552,7 +552,7 @@ async def async_create( # type: ignore
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: asyncio.get_event_loop()
.run_until_complete(self._client.files.async_create(file))
Expand Down
6 changes: 3 additions & 3 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing_extensions import NotRequired, Unpack

from replicate.file import upload_file
from replicate.file import base64_encode_file
from replicate.identifier import ModelVersionIdentifier
from replicate.json import encode_json
from replicate.model import Model
Expand Down Expand Up @@ -283,7 +283,7 @@ def create( # type: ignore
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: self._client.files.create(file).urls["get"],
)
Expand Down Expand Up @@ -324,7 +324,7 @@ async def async_create(
if input is not None:
input = encode_json(
input,
upload_file=upload_file
upload_file=base64_encode_file
if file_encoding_strategy == "base64"
else lambda file: asyncio.get_event_loop()
.run_until_complete(self._client.files.async_create(file))
Expand Down
32 changes: 32 additions & 0 deletions tests/test_file.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import io
import tempfile

import pytest

import replicate
from replicate.file import base64_encode_file


@pytest.mark.vcr("file-operations.yaml")
Expand Down Expand Up @@ -56,3 +58,33 @@ async def test_file_operations(async_flag):
file_list = replicate.files.list()

assert all(f.id != file_id for f in file_list)


@pytest.mark.parametrize(
"content, filename, expected",
[
(b"Hello, World!", "test.txt", "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="),
(b"\x89PNG\r\n\x1a\n", "image.png", "data:image/png;base64,iVBORw0KGgo="),
(
"{'key': 'value'}",
"data.json",
"data:application/json;base64,eydrZXknOiAndmFsdWUnfQ==",
),
(
b"Random bytes",
None,
"data:application/octet-stream;base64,UmFuZG9tIGJ5dGVz",
),
],
)
def test_base64_encode_file(content, filename, expected):
# Create a file-like object with the given content
file = io.BytesIO(content if isinstance(content, bytes) else content.encode())

# Set the filename if provided
if filename:
file.name = filename

# Call the function and check the result
result = base64_encode_file(file)
assert result == expected

0 comments on commit e19cb78

Please sign in to comment.