Skip to content

Add support for files API endpoints #226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ rye = { dev-dependencies = [
] }

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = "tests/"

[tool.setuptools]
Expand Down
3 changes: 2 additions & 1 deletion replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
async_paginate = _async_paginate

collections = default_client.collections
hardware = default_client.hardware
deployments = default_client.deployments
files = default_client.files
hardware = default_client.hardware
models = default_client.models
predictions = default_client.predictions
trainings = default_client.trainings
Expand Down
8 changes: 8 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from replicate.collection import Collections
from replicate.deployment import Deployments
from replicate.exceptions import ReplicateError
from replicate.file import Files
from replicate.hardware import HardwareNamespace as Hardware
from replicate.model import Models
from replicate.prediction import Predictions
Expand Down Expand Up @@ -117,6 +118,13 @@ def deployments(self) -> Deployments:
"""
return Deployments(client=self)

@property
def files(self) -> Files:
    """
    Return a `Files` namespace bound to this client for file operations.
    """
    namespace = Files(client=self)
    return namespace

@property
def hardware(self) -> Hardware:
"""
Expand Down
204 changes: 204 additions & 0 deletions replicate/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import base64
import io
import json
import mimetypes
import os
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union

import httpx
from typing_extensions import NotRequired, Unpack

from replicate.resource import Namespace, Resource


class File(Resource):
    """
    Resource describing a file that was uploaded to Replicate.

    An uploaded file can be supplied as an input to a model.
    """

    id: str
    """Identifier of the file."""

    name: str
    """Name of the file."""

    content_type: str
    """Content type of the file."""

    size: int
    """Size of the file, in bytes."""

    etag: str
    """ETag of the file."""

    checksums: Dict[str, str]
    """Checksums of the file, as a string-to-string mapping."""

    metadata: Dict[str, Any]
    """Metadata attached to the file."""

    created_at: str
    """Time at which the file was created."""

    expires_at: Optional[str]
    """Time at which the file will expire; may be None."""

    urls: Dict[str, str]
    """URLs of the file, as a string-to-string mapping."""


class Files(Namespace):
    """
    Namespace for operations related to files.

    Every operation has a synchronous method and an `async_`-prefixed
    coroutine counterpart that issues the same request through the
    client's asynchronous request method.
    """

    class CreateFileParams(TypedDict):
        """Parameters for creating a file."""

        filename: NotRequired[str]
        """The name of the file."""

        content_type: NotRequired[str]
        """The content type of the file."""

        metadata: NotRequired[Dict[str, Any]]
        """The file metadata."""

    def create(
        self,
        file: Union[str, pathlib.Path, BinaryIO, io.IOBase],
        **params: Unpack["Files.CreateFileParams"],
    ) -> File:
        """
        Upload a file that can be passed as an input when running a model.

        Args:
            file: A path to a file on disk, or an open file-like object.
            **params: Optional `filename`, `content_type` and `metadata`.
        Returns:
            The created file.
        Raises:
            ValueError: If `file` is neither a path nor a file-like object.
        """

        if isinstance(file, (str, pathlib.Path)):
            # Open the path ourselves so the handle is closed once the upload returns.
            with open(file, "rb") as f:
                return self.create(f, **params)
        elif not isinstance(file, (io.IOBase, BinaryIO)):
            raise ValueError(
                "Unsupported file type. Must be a file path or file-like object."
            )

        # NOTE(review): timeout=None is forwarded to the client; presumably this
        # disables the request timeout for large uploads -- confirm in the client.
        resp = self._client._request(
            "POST", "/v1/files", timeout=None, **_create_file_params(file, **params)
        )

        return _json_to_file(resp.json())

    async def async_create(
        self,
        file: Union[str, pathlib.Path, BinaryIO, io.IOBase],
        **params: Unpack["Files.CreateFileParams"],
    ) -> File:
        """
        Upload a file asynchronously that can be passed as an input when running a model.

        Args:
            file: A path to a file on disk, or an open file-like object.
            **params: Optional `filename`, `content_type` and `metadata`.
        Returns:
            The created file.
        Raises:
            ValueError: If `file` is neither a path nor a file-like object.
        """

        if isinstance(file, (str, pathlib.Path)):
            with open(file, "rb") as f:
                # Await the async variant: calling the synchronous `create` here
                # would perform a blocking upload inside the coroutine. The handle
                # stays open until the awaited upload has finished.
                return await self.async_create(f, **params)
        elif not isinstance(file, (io.IOBase, BinaryIO)):
            raise ValueError(
                "Unsupported file type. Must be a file path or file-like object."
            )

        resp = await self._client._async_request(
            "POST", "/v1/files", timeout=None, **_create_file_params(file, **params)
        )

        return _json_to_file(resp.json())

    def get(self, file_id: str) -> File:
        """
        Get an uploaded file by its ID.

        Args:
            file_id: The ID of the file.
        Returns:
            The file built from the response body.
        """

        resp = self._client._request("GET", f"/v1/files/{file_id}")
        return _json_to_file(resp.json())

    async def async_get(self, file_id: str) -> File:
        """
        Get an uploaded file by its ID asynchronously.

        Args:
            file_id: The ID of the file.
        Returns:
            The file built from the response body.
        """

        resp = await self._client._async_request("GET", f"/v1/files/{file_id}")
        return _json_to_file(resp.json())

    def list(self) -> List[File]:
        """
        List all uploaded files.

        Returns:
            The files found under the "results" key of the response body, or
            an empty list when that key is absent. Only this single response
            is read.
        """

        resp = self._client._request("GET", "/v1/files")
        return [_json_to_file(obj) for obj in resp.json().get("results", [])]

    async def async_list(self) -> List[File]:
        """
        List all uploaded files asynchronously.

        Returns:
            The files found under the "results" key of the response body, or
            an empty list when that key is absent. Only this single response
            is read.
        """

        resp = await self._client._async_request("GET", "/v1/files")
        return [_json_to_file(obj) for obj in resp.json().get("results", [])]

    def delete(self, file_id: str) -> None:
        """
        Delete an uploaded file by its ID.

        Args:
            file_id: The ID of the file.
        """

        # The response is intentionally discarded.
        _ = self._client._request("DELETE", f"/v1/files/{file_id}")

    async def async_delete(self, file_id: str) -> None:
        """
        Delete an uploaded file by its ID asynchronously.

        Args:
            file_id: The ID of the file.
        """

        # The response is intentionally discarded.
        _ = await self._client._async_request("DELETE", f"/v1/files/{file_id}")


def _create_file_params(
file: Union[BinaryIO, io.IOBase],
**params: Unpack["Files.CreateFileParams"],
) -> Dict[str, Any]:
file.seek(0)

if params is None:
params = {}

filename = params.get("filename", os.path.basename(getattr(file, "name", "file")))
content_type = (
params.get("content_type")
or mimetypes.guess_type(filename)[0]
or "application/octet-stream"
)
metadata = params.get("metadata")

data = {}
if metadata:
data["metadata"] = json.dumps(metadata)

return {
"files": {"content": (filename, file, content_type)},
"data": data,
}


def _json_to_file(json: Dict[str, Any]) -> File:  # pylint: disable=redefined-outer-name
    """
    Build a `File` by passing the entries of a decoded JSON object as keyword arguments.
    """
    file_resource = File(**json)
    return file_resource


def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
    """
    Upload a file to the server.

    When `output_file_prefix` is given, the file is sent with an HTTP PUT to
    the prefix joined with the file's base name. Otherwise nothing is sent
    and the file contents are returned inline as a base64 `data:` URI.

    Args:
        file: A file handle to upload. It is rewound to its start first.
        output_file_prefix: A string to prepend to the output file name.
    Returns:
        str: A URL to the uploaded file.
    Raises:
        httpx.HTTPStatusError: If the PUT request returns an error status.
    """
    # Lifted straight from cog.files

    file.seek(0)

    # `name` is not always a usable path: file objects opened from a descriptor
    # (e.g. tempfile.TemporaryFile) report an int. Treat such names as absent
    # instead of letting os.path / mimetypes raise TypeError.
    raw_name = getattr(file, "name", None)
    if isinstance(raw_name, (str, bytes, os.PathLike)):
        name = os.fsdecode(raw_name)
    else:
        name = None

    if output_file_prefix is not None:
        url = output_file_prefix + os.path.basename(
            "output" if name is None else name
        )
        resp = httpx.put(url, files={"file": file}, timeout=None)  # type: ignore
        resp.raise_for_status()

        return url

    body = file.read()
    # Ensure the file contents are bytes (text-mode handles yield str)
    body = body.encode("utf-8") if isinstance(body, str) else body
    encoded_body = base64.b64encode(body).decode("utf-8")
    mime_type = (
        mimetypes.guess_type("" if name is None else name)[0]
        or "application/octet-stream"
    )
    return f"data:{mime_type};base64,{encoded_body}"
40 changes: 0 additions & 40 deletions replicate/files.py

This file was deleted.

2 changes: 1 addition & 1 deletion replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError, ReplicateError
from replicate.files import upload_file
from replicate.file import upload_file
from replicate.json import encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
Expand Down
2 changes: 1 addition & 1 deletion replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from typing_extensions import NotRequired, Unpack

from replicate.files import upload_file
from replicate.file import upload_file
from replicate.identifier import ModelVersionIdentifier
from replicate.json import encode_json
from replicate.model import Model
Expand Down
Loading