Skip to content

Commit 0020d0f

Browse files
committed
Rename upload_file to base64_encode_file
Remove unused output_file_prefix parameter Add test coverage for base64_encode_file Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 3cc1974 commit 0020d0f

File tree

6 files changed

+40
-30
lines changed

6 files changed

+40
-30
lines changed

replicate/deployment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing_extensions import Unpack, deprecated
55

66
from replicate.account import Account
7-
from replicate.file import upload_file
7+
from replicate.file import base64_encode_file
88
from replicate.json import encode_json
99
from replicate.pagination import Page
1010
from replicate.prediction import (
@@ -424,7 +424,7 @@ def create(
424424
if input is not None:
425425
input = encode_json(
426426
input,
427-
upload_file=upload_file
427+
upload_file=base64_encode_file
428428
if file_encoding_strategy == "base64"
429429
else lambda file: self._client.files.create(file).urls["get"],
430430
)
@@ -451,7 +451,7 @@ async def async_create(
451451
if input is not None:
452452
input = encode_json(
453453
input,
454-
upload_file=upload_file
454+
upload_file=base64_encode_file
455455
if file_encoding_strategy == "base64"
456456
else lambda file: asyncio.get_event_loop()
457457
.run_until_complete(self._client.files.async_create(file))
@@ -489,7 +489,7 @@ def create(
489489
if input is not None:
490490
input = encode_json(
491491
input,
492-
upload_file=upload_file
492+
upload_file=base64_encode_file
493493
if file_encoding_strategy == "base64"
494494
else lambda file: self._client.files.create(file).urls["get"],
495495
)
@@ -519,7 +519,7 @@ async def async_create(
519519
if input is not None:
520520
input = encode_json(
521521
input,
522-
upload_file=upload_file
522+
upload_file=base64_encode_file
523523
if file_encoding_strategy == "base64"
524524
else lambda file: asyncio.get_event_loop()
525525
.run_until_complete(self._client.files.async_create(file))

replicate/file.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -171,33 +171,23 @@ def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-ou
171171
return File(**json)
172172

173173

174-
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
174+
def base64_encode_file(file: io.IOBase) -> str:
175175
"""
176-
Upload a file to the server.
176+
Base64 encode a file.
177177
178178
Args:
179179
file: A file handle to upload.
180-
output_file_prefix: A string to prepend to the output file name.
181180
Returns:
182-
str: A URL to the uploaded file.
181+
str: A base64-encoded data URI.
183182
"""
184-
# Lifted straight from cog.files
185183

186184
file.seek(0)
187-
188-
if output_file_prefix is not None:
189-
name = getattr(file, "name", "output")
190-
url = output_file_prefix + os.path.basename(name)
191-
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
192-
resp.raise_for_status()
193-
194-
return url
195-
196185
body = file.read()
186+
197187
# Ensure the file handle is in bytes
198188
body = body.encode("utf-8") if isinstance(body, str) else body
199189
encoded_body = base64.b64encode(body).decode("utf-8")
200-
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
190+
201191
mime_type = (
202192
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
203193
)

replicate/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
55

66
from replicate.exceptions import ReplicateException
7-
from replicate.file import upload_file
7+
from replicate.file import base64_encode_file
88
from replicate.identifier import ModelVersionIdentifier
99
from replicate.json import encode_json
1010
from replicate.pagination import Page
@@ -399,7 +399,7 @@ def create(
399399
if input is not None:
400400
input = encode_json(
401401
input,
402-
upload_file=upload_file
402+
upload_file=base64_encode_file
403403
if file_encoding_strategy == "base64"
404404
else lambda file: self._client.files.create(file).urls["get"],
405405
)
@@ -429,7 +429,7 @@ async def async_create(
429429
if input is not None:
430430
input = encode_json(
431431
input,
432-
upload_file=upload_file
432+
upload_file=base64_encode_file
433433
if file_encoding_strategy == "base64"
434434
else lambda file: asyncio.get_event_loop()
435435
.run_until_complete(self._client.files.async_create(file))

replicate/prediction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_extensions import NotRequired, TypedDict, Unpack
2020

2121
from replicate.exceptions import ModelError, ReplicateError
22-
from replicate.file import upload_file
22+
from replicate.file import base64_encode_file
2323
from replicate.json import encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource
@@ -460,7 +460,7 @@ def create( # type: ignore
460460
if input is not None:
461461
input = encode_json(
462462
input,
463-
upload_file=upload_file
463+
upload_file=base64_encode_file
464464
if file_encoding_strategy == "base64"
465465
else lambda file: self._client.files.create(file).urls["get"],
466466
)
@@ -552,7 +552,7 @@ async def async_create( # type: ignore
552552
if input is not None:
553553
input = encode_json(
554554
input,
555-
upload_file=upload_file
555+
upload_file=base64_encode_file
556556
if file_encoding_strategy == "base64"
557557
else lambda file: asyncio.get_event_loop()
558558
.run_until_complete(self._client.files.async_create(file))

replicate/training.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typing_extensions import NotRequired, Unpack
1616

17-
from replicate.file import upload_file
17+
from replicate.file import base64_encode_file
1818
from replicate.identifier import ModelVersionIdentifier
1919
from replicate.json import encode_json
2020
from replicate.model import Model
@@ -283,7 +283,7 @@ def create( # type: ignore
283283
if input is not None:
284284
input = encode_json(
285285
input,
286-
upload_file=upload_file
286+
upload_file=base64_encode_file
287287
if file_encoding_strategy == "base64"
288288
else lambda file: self._client.files.create(file).urls["get"],
289289
)
@@ -324,7 +324,7 @@ async def async_create(
324324
if input is not None:
325325
input = encode_json(
326326
input,
327-
upload_file=upload_file
327+
upload_file=base64_encode_file
328328
if file_encoding_strategy == "base64"
329329
else lambda file: asyncio.get_event_loop()
330330
.run_until_complete(self._client.files.async_create(file))

tests/test_file.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import tempfile
2+
import io
23

34
import pytest
45

56
import replicate
6-
7+
from replicate.file import base64_encode_file
78

89
@pytest.mark.vcr("file-operations.yaml")
910
@pytest.mark.asyncio
@@ -56,3 +57,22 @@ async def test_file_operations(async_flag):
5657
file_list = replicate.files.list()
5758

5859
assert all(f.id != file_id for f in file_list)
60+
61+
62+
@pytest.mark.parametrize("content, filename, expected", [
63+
(b"Hello, World!", "test.txt", "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="),
64+
(b"\x89PNG\r\n\x1a\n", "image.png", ""),
65+
("{'key': 'value'}", "data.json", "data:application/json;base64,eydrZXknOiAndmFsdWUnfQ=="),
66+
(b"Random bytes", None, "data:application/octet-stream;base64,UmFuZG9tIGJ5dGVz"),
67+
])
68+
def test_base64_encode_file(content, filename, expected):
69+
# Create a file-like object with the given content
70+
file = io.BytesIO(content if isinstance(content, bytes) else content.encode())
71+
72+
# Set the filename if provided
73+
if filename:
74+
file.name = filename
75+
76+
# Call the function and check the result
77+
result = base64_encode_file(file)
78+
assert result == expected

0 commit comments

Comments
 (0)