diff --git a/pyproject.toml b/pyproject.toml index 341846ff..12d88b00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ rye = { dev-dependencies = [ ] } [tool.pytest.ini_options] +asyncio_mode = "auto" testpaths = "tests/" [tool.setuptools] diff --git a/replicate/__init__.py b/replicate/__init__.py index c9159cc8..0e6838d7 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -14,8 +14,9 @@ async_paginate = _async_paginate collections = default_client.collections -hardware = default_client.hardware deployments = default_client.deployments +files = default_client.files +hardware = default_client.hardware models = default_client.models predictions = default_client.predictions trainings = default_client.trainings diff --git a/replicate/client.py b/replicate/client.py index b28da8b0..5cee7a6f 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -24,6 +24,7 @@ from replicate.collection import Collections from replicate.deployment import Deployments from replicate.exceptions import ReplicateError +from replicate.file import Files from replicate.hardware import HardwareNamespace as Hardware from replicate.model import Models from replicate.prediction import Predictions @@ -117,6 +118,13 @@ def deployments(self) -> Deployments: """ return Deployments(client=self) + @property + def files(self) -> Files: + """ + Namespace for operations related to files. + """ + return Files(client=self) + @property def hardware(self) -> Hardware: """ diff --git a/replicate/file.py b/replicate/file.py new file mode 100644 index 00000000..f489ae76 --- /dev/null +++ b/replicate/file.py @@ -0,0 +1,204 @@ +import base64 +import io +import json +import mimetypes +import os +import pathlib +from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union + +import httpx +from typing_extensions import NotRequired, Unpack + +from replicate.resource import Namespace, Resource + + +class File(Resource): + """ + A file uploaded to Replicate that can be used as an input to a model. + """ + + id: str + """The ID of the file.""" + + name: str + """The name of the file.""" + + content_type: str + """The content type of the file.""" + + size: int + """The size of the file in bytes.""" + + etag: str + """The ETag of the file.""" + + checksums: Dict[str, str] + """The checksums of the file.""" + + metadata: Dict[str, Any] + """The metadata of the file.""" + + created_at: str + """The time the file was created.""" + + expires_at: Optional[str] + """The time the file will expire.""" + + urls: Dict[str, str] + """The URLs of the file.""" + + +class Files(Namespace): + class CreateFileParams(TypedDict): + """Parameters for creating a file.""" + + filename: NotRequired[str] + """The name of the file.""" + + content_type: NotRequired[str] + """The content type of the file.""" + + metadata: NotRequired[Dict[str, Any]] + """The file metadata.""" + + def create( + self, + file: Union[str, pathlib.Path, BinaryIO, io.IOBase], + **params: Unpack["Files.CreateFileParams"], + ) -> File: + """ + Upload a file that can be passed as an input when running a model. + """ + + if isinstance(file, (str, pathlib.Path)): + with open(file, "rb") as f: + return self.create(f, **params) + elif not isinstance(file, (io.IOBase, BinaryIO)): + raise ValueError( + "Unsupported file type. Must be a file path or file-like object." + ) + + resp = self._client._request( + "POST", "/v1/files", timeout=None, **_create_file_params(file, **params) + ) + + return _json_to_file(resp.json()) + + async def async_create( + self, + file: Union[str, pathlib.Path, BinaryIO, io.IOBase], + **params: Unpack["Files.CreateFileParams"], + ) -> File: + """Upload a file asynchronously that can be passed as an input when running a model.""" + + if isinstance(file, (str, pathlib.Path)): + with open(file, "rb") as f: + return self.create(f, **params) + elif not isinstance(file, (io.IOBase, BinaryIO)): + raise ValueError( + "Unsupported file type. Must be a file path or file-like object." + ) + + resp = await self._client._async_request( + "POST", "/v1/files", timeout=None, **_create_file_params(file, **params) + ) + + return _json_to_file(resp.json()) + + def get(self, file_id: str) -> File: + """Get an uploaded file by its ID.""" + + resp = self._client._request("GET", f"/v1/files/{file_id}") + return _json_to_file(resp.json()) + + async def async_get(self, file_id: str) -> File: + """Get an uploaded file by its ID asynchronously.""" + + resp = await self._client._async_request("GET", f"/v1/files/{file_id}") + return _json_to_file(resp.json()) + + def list(self) -> List[File]: + """List all uploaded files.""" + + resp = self._client._request("GET", "/v1/files") + return [_json_to_file(obj) for obj in resp.json().get("results", [])] + + async def async_list(self) -> List[File]: + """List all uploaded files asynchronously.""" + + resp = await self._client._async_request("GET", "/v1/files") + return [_json_to_file(obj) for obj in resp.json().get("results", [])] + + def delete(self, file_id: str) -> None: + """Delete an uploaded file by its ID.""" + + _ = self._client._request("DELETE", f"/v1/files/{file_id}") + + async def async_delete(self, file_id: str) -> None: + """Delete an uploaded file by its ID asynchronously.""" + + _ = await self._client._async_request("DELETE", f"/v1/files/{file_id}") + + +def _create_file_params( + file: Union[BinaryIO, io.IOBase], + **params: Unpack["Files.CreateFileParams"], +) -> Dict[str, Any]: + file.seek(0) + + if params is None: + params = {} + + filename = params.get("filename", os.path.basename(getattr(file, "name", "file"))) + content_type = ( + params.get("content_type") + or mimetypes.guess_type(filename)[0] + or "application/octet-stream" + ) + metadata = params.get("metadata") + + data = {} + if metadata: + data["metadata"] = json.dumps(metadata) + + return { + "files": {"content": (filename, file, content_type)}, + "data": data, + } + + +def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-outer-name + return File(**json) + + +def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str: + """ + Upload a file to the server. + + Args: + file: A file handle to upload. + output_file_prefix: A string to prepend to the output file name. + Returns: + str: A URL to the uploaded file. + """ + # Lifted straight from cog.files + + file.seek(0) + + if output_file_prefix is not None: + name = getattr(file, "name", "output") + url = output_file_prefix + os.path.basename(name) + resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore + resp.raise_for_status() + + return url + + body = file.read() + # Ensure the file handle is in bytes + body = body.encode("utf-8") if isinstance(body, str) else body + encoded_body = base64.b64encode(body).decode("utf-8") + # Use getattr to avoid mypy complaints about io.IOBase having no attribute name + mime_type = ( + mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream" + ) + return f"data:{mime_type};base64,{encoded_body}" diff --git a/replicate/files.py b/replicate/files.py deleted file mode 100644 index e761ed3f..00000000 --- a/replicate/files.py +++ /dev/null @@ -1,40 +0,0 @@ -import base64 -import io -import mimetypes -import os -from typing import Optional - -import httpx - - -def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str: - """ - Upload a file to the server. - - Args: - file: A file handle to upload. - output_file_prefix: A string to prepend to the output file name. - Returns: - str: A URL to the uploaded file. - """ - # Lifted straight from cog.files - - file.seek(0) - - if output_file_prefix is not None: - name = getattr(file, "name", "output") - url = output_file_prefix + os.path.basename(name) - resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore - resp.raise_for_status() - - return url - - body = file.read() - # Ensure the file handle is in bytes - body = body.encode("utf-8") if isinstance(body, str) else body - encoded_body = base64.b64encode(body).decode("utf-8") - # Use getattr to avoid mypy complaints about io.IOBase having no attribute name - mime_type = ( - mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream" - ) - return f"data:{mime_type};base64,{encoded_body}" diff --git a/replicate/prediction.py b/replicate/prediction.py index 74c1946e..055d9ca2 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -19,7 +19,7 @@ from typing_extensions import NotRequired, TypedDict, Unpack from replicate.exceptions import ModelError, ReplicateError -from replicate.files import upload_file +from replicate.file import upload_file from replicate.json import encode_json from replicate.pagination import Page from replicate.resource import Namespace, Resource diff --git a/replicate/training.py b/replicate/training.py index 8125cdf3..83413ef2 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -13,7 +13,7 @@ from typing_extensions import NotRequired, Unpack -from replicate.files import upload_file +from replicate.file import upload_file from replicate.identifier import ModelVersionIdentifier from replicate.json import encode_json from replicate.model import Model diff --git a/tests/cassettes/file-operations.yaml b/tests/cassettes/file-operations.yaml new file mode 100644 index 00000000..076bea45 --- /dev/null +++ b/tests/cassettes/file-operations.yaml @@ -0,0 +1,338 @@ +interactions: +- request: + body: "--f64a6c6ac4fed507b9430635c33d4f7c\r\nContent-Disposition: form-data; name=\"content\"; + filename=\"test_fileo20o3wth.txt\"\r\nContent-Type: text/plain\r\n\r\n\r\n--f64a6c6ac4fed507b9430635c33d4f7c--\r\n" + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '186' + content-type: + - multipart/form-data; boundary=f64a6c6ac4fed507b9430635c33d4f7c + host: + - api.replicate.com + user-agent: + - replicate-python/0.30.1 + method: POST + uri: https://api.replicate.com/v1/files + response: + body: + string: '{"id":"NGZiNmY2YzQtMThhZi00ZjcyLWFhZjktODg4NTY0NWNlMDEy","name":"test_fileo20o3wth.txt","content_type":"text/plain","size":0,"etag":"d41d8cd98f00b204e9800998ecf8427e","checksums":{"sha256":"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855","md5":"d41d8cd98f00b204e9800998ecf8427e"},"metadata":{},"created_at":"2024-08-22T12:26:51.079Z","expires_at":"2024-08-23T12:26:51.079Z","urls":{"get":"https://api.replicate.com/v1/files/NGZiNmY2YzQtMThhZi00ZjcyLWFhZjktODg4NTY0NWNlMDEy"}}' + headers: + CF-RAY: + - 8b72da83eda0c39f-SEA + Cf-Placement: + - remote-SJC + Connection: + - keep-alive + Content-Length: + - '493' + Content-Type: + - application/json; charset=UTF-8 + Date: + - Thu, 22 Aug 2024 12:26:51 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=3zHX3amjkYcsqliQHujgqvEty9JAbDn1rvG00PP4bfnbsWV11zLOspVtSDl%2F3XyH2ECydfNrW8t9qj4QSmX%2F2jkeas0Xi18PbABraUJigcXzjaNwlh2gQ%2BtRyM4hgS8jcZvO"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Vary: + - Accept-Encoding + alt-svc: + - h3=":443"; ma=86400 + status: + code: 201 + message: Created +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.30.1 + method: GET + uri: https://api.replicate.com/v1/files/NGZiNmY2YzQtMThhZi00ZjcyLWFhZjktODg4NTY0NWNlMDEy + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//lNHLTsMwEAXQf5l1SBzHSWyvC2xoEFKlqtlUjjNp3OaleAp9qP+OCjvY + wP7oau6dK7gaNBTPpSv6Dd9c3mi5atvSMVbu7fll/dSW+wO9LnaiWG1YsS665eLxDAEMpkfQQOhp + 27gOR87G5IPakE4EAdhxIBxoS+fpm50omjrjBgjAuwuCZgEgmR1oqEVcS1sr2TBWcSZQScaUkmgb + KXiO97gW7cEfew/6Cr41PM1AAyYVs0JwJRsb21go01SNsFKprKkUFzw3KGIUmVCVSoQ1QqVKxVUu + U17JNIUA+jr9ywW3AHokUxsyoK+3AOyMhrDeGgINnHHxwOQD56uYa57pNA5ZrkoIAE+Tm9H/cMkv + d5y7r247vMOWaPI6iszkwhmnzllDGNqxj97j6L62j/79stvtEwAA//8DAFK0/HbtAQAA + headers: + CF-RAY: + - 8b72da85be95c39f-SEA + Cf-Placement: + - remote-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json; charset=UTF-8 + Date: + - Thu, 22 Aug 2024 12:26:51 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=PTe6WTpBSmtfJrGQ5lngNg%2FC6DkfYk%2BCWVjYfW4g0eTVVFGTS1D1w9a59eGHpaGhdSXtARBi8O21Imju21kTMjQ8F%2BDMxqYZZQwKoCJV9VDGu0bCLVfCCzlnB7vVaIZhrzZU"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + Vary: + - Accept-Encoding + alt-svc: + - h3=":443"; ma=86400 + status: + code: 200 + message: OK +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.30.1 + method: GET + uri: https://api.replicate.com/v1/files + response: + body: + string: !!binary | + H4sIAAAAAAAAA+zcUW+jOhYA4P/Cc5cxxiQhb5lLgohiI+Y6QfZqVaWEBgwkaetOAqP57yvT3ind + RUrcZx47Ooym8wWf4+Pj/DIO6UUa08NrWd4Zp+f0Z358ffnn5+f05bWUL8b037+MfGdMDQwji8DI + 4iKSLGYA5wCQJstXtHAYLSRpcINpAnGT1JgmtXFnHLZVakwNmb7I+8e8TEc/k4OQD44pL9K4M5Lj + QaYHeS/r01vYRX47ldv8YNwZL3mTGlNwZ6Ryuzemxg5Zu0mycyePADxAgFJ3AoDrTtLkcYLgOFV/ + XZYmxctr9WJMfxkv2RY6I2NqpPYDSBCC7uQxsRILudvHh0eUTFx39PjgQgTH2xRZKRoh98G1UbJF + ruO61sN44sCHieMYd0a1c275F/y+M6pUbndbuTWmv37fGclzupXp7n4rjakBAUT/ApN/QUgtOIVw + iqCJXMSNOyO9nPLn9OV/4uz/i3t9LtvfbZ+qwEzK08v027ftKTef01OZJ1uZmsmx+vbT+qb+t1++ + aZP9/n33j7a3djANLO4FMvR/CJYDwGlWruLgzGIsSbW2MA1qXi1LJrDVo+28HB4eygMctKEztR3T + tuA17W6cvrYu2SftMxZlwZqZxDBo8N/q0e/5iu7bDw8Xi4pV+Mzj4EK8rOjRftk9CblDctBW76xt + Aji+4d3+E/cVbT2yrnaCuAhsTLEMvcJifwOA6bJc0b2N/bkkHruEXgTV8sGrCPRoo71Mt9mTPWhD + NIWWORpdXcm7cV/Q1iTraPtcYG9XELqWhOKGt48uqlW8zMJ4Lrk3a7AXnHkVIeLNYI/2+PU03j1n + T4O25U4hNMcj94r2pzh9bV2yjjYNYEjLgkH1QdnkuE0CSbOKlzn31hJXm4r7G4Hj9RnTXdWjfbHy + Au7tn4M2tKcImS64pv0pTl9bl6yjLTZV6DGbwEgSyBDPAcA+blbxvOG++rM5JB6DpJnVTGRlX95G + p2L8OnkYtFX1hczJ+JYq7U+cvrYuWUe72RQEzmuikgDd5OFfADA6Q6t4UeBmLsOYASYSwOkesqb3 + 3X7a7Zpn6+wO2tZkCsbmeGJfW8m7cfraumQf2gSubR6v7ZDOJKeJjWsAWMzzVRw0DAaSQOyE8Rpi + n5c8XvbV5I+Hqj5mT8O73eZj24QI3JC3/8Rpa2uTdbS9AoTeXqV0yekckr8AwGJur2iWhWpZoHuA + /ehMvEXB/aCvJhdH20lP52rQhs4UWKZlj66v5B9x+tq6ZB1tn+ekYpA1kcQ0y1QS4CKpV/Ei46KQ + obdHhDJAYlJib97XSztCcLTPMhu04WjqWCYYX63SunH62rpkHW2a1NyfI1ZhVbrbRLXhvKVqzDht + s8abXbBK+TTLcLXJerQn++Q0TqqXQRuiqW2b1tUd2Kc4fW1dsq52JrBXtqU7a9a1Svmhh8+reH7G + qllTzWviR5DRtcPFsq9KewZZvs8vQ598ao3anZV9rSb/FPcFbU2yjnZMckKjCxYzSWiCVBIgzSZv + H40jyf3ADr0yD/3Awj7pe7cvx1MtUzhUaeqdBcAco8n1d/sjTl9bl6yrzQvuRTaPA0nE/qKWBUx3 + xSqeW8Tby9AjFfZmDhOkxA1GPdpnNztZ9XM6aEOg8rE7vlqldeO+oK1J1tEWczv05zYRa8mqjeB/ + AcA9tXkLHJUYQm995jQ4Y7GsmFj3dU73B0s8JenjoG2NphCYELnXV/KPOH1tXbKudgFJjM+kYZLH + S8FqALAXoFU8v2BvJomIGtZ8L0Of1dzb9PXSyonzNDlmQ3dF1doAmej6GVg37gvammQd7WaTEbpH + bYM1Zk7bYq/W9YoWdujNJWvYGXuFxb1djmOS92jXqKjOaTKs5O8r9AjctpK/x+lr65J9aIfessAx + QyyeSywYUkfjrCmrVbyoQppI1mQZ9heCUWaH3rrp0S5eM2ufJJdB23KnyDLd653Tbpy2tjZZV5uU + PI4sRlXL7e1onEEGVqoh0ySSiTJjcN7gKnBCurZ7tPNJZTn3x2G/3Z5bT8zJ6Iaa/CPuC9qaZF3t + ssBikxE/ktgP6rac93ZiFS8LJpjEIstxs69D73vFqnlf53QydnaPdTlMM7Q7K8uErnXDDuxP3Be0 + Nck62hQ7WCwz0obhRqV8Qn+UK1pceLyW2IscDJcZo6zBXtS3kmentHi6fz0M2hBM0cgcjezrefsj + Tl9bl6yrXVjcj+p2WaCsbg9UxMxaxXMn9JlkzUYQL0Kk2gjsRZe+SaVLvt3nmTVoQzC1Rya63jnt + xn1BW5PsQ5vBTcaqRcFEInEcQLV5IxUvV/FCkCqQuArOhDKb+POaC97XS0OncSlcdzxoq3d2bALH + ueHd/hOnra1N1tGmaiw1uKhBZNxEALePRvaKlkJt3li1yDBcVMTblSTu7aUdmnvgnuzXQdsaTS3X + HE9u6K58xOlr65J1tAWviFgK1ZjBcF0TNa4a/yjVkBOLAxnGC8EqDLjIShL/6DsVOb4WtSvT46D9 + dm7tXO+udOP0tXXJOtrqwMTjuWqxYzFrQlXO06hexcuKtcMR30sumMPUZs4P+jqnkySpdk/3Q02u + 9tG2bQJww8zpR5y+ti5ZV3shOI1sBiMZ0uCiHiViJ1Z0X+O2wEsgo3uHq5RPZ+e+CePdMcsv+XbQ + hgrSdNDo+u2Bj7gvaGuSfWhzL4JqWWCUSVIRoQ5UmCjeBhjpTJJm3zDIC05niHu9k0oTeyR2iRy6 + K+25tWOiqyeen+K0tbXJuto/cizmMKTqSsmuUkkAewFUByqhF8m2ceMlZ9wU6tpJn/YYpYcjsJtB + W82STkw4vmXm9E/cF7Q1yTra/qYkcdCocVVGfxTtow27rGhWEahGYdY1aXZF6LEm9IK+W3/lT3A5 + WPkwzfCejxEEN+Xt9zh9bV2yjjYNHOJHgFMsSbws2gFGGp1XbYm/lthXl4kWgsW8wv3aL5f9Nj3e + D1OI7WkHNJ2rZ2Cf4vS1dcm62tjCVXQhqgkDmfOWBNa2GkUnNJCqtCf+plBFXuiv+/bb7v7y5Ga7 + 4R7Ye49sgkY39dIm6KvammRd7X2DvU3JK6wuEIn2wiCcgxXNcpUEmCCC+D8yLApERO+tv31TjdPd + 9n7QfjvbcgC66QzsPe4L2ppkXe3Cxg2pwhhL7C2LdhCiWaO34zOVBBgiXmCxprBCr+j9boaHCTw/ + JcWg/XZTFyLnphu973Ff0NYk62iLec0aXhI4l1zMnTYJ+BFSHxRcRZJBkoXxIgv9+YV72Olbycvi + URztYS7tfU58ZNk3zZO/x+lr65J1tNv2+rJqv8KhWbQFnromqlrs6vI/qZhF6P7MPV6E8bpPuynA + w6GydoP2W9fEujqF+ClOX1uXrKtNBIs3JfeZJGo0VY2rNntHfVCISgz+uubV2g79CIRx1FeTZz+d + 5vScokH77V42cp2b7m+/x31BW5Ps9+///P4vAAAA//8DAJIltuJ9SQAA + headers: + CF-RAY: + - 8b72da871f4bc39f-SEA + Cf-Placement: + - remote-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json; charset=UTF-8 + Date: + - Thu, 22 Aug 2024 12:26:51 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=bYdBu1NoyhwOzfM%2FJTv3B8BsO45cdyXYsf6ZvUqgkivjk1wKNmuJt9xdFE8UGhUwN%2BdRNTpY1ZyqGsbugeFFKEDNoWgAZERBDVV1EaD%2F1oF2%2FoF6BfFVjkuKXpctEhqgNFTr"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + Vary: + - Accept-Encoding + alt-svc: + - h3=":443"; ma=86400 + status: + code: 200 + message: OK +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.30.1 + method: DELETE + uri: https://api.replicate.com/v1/files/NGZiNmY2YzQtMThhZi00ZjcyLWFhZjktODg4NTY0NWNlMDEy + response: + body: + string: '' + headers: + CF-RAY: + - 8b72da887ff6c39f-SEA + Cf-Placement: + - remote-SJC + Connection: + - keep-alive + Content-Type: + - text/plain; charset=UTF-8 + Date: + - Thu, 22 Aug 2024 12:26:51 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=jeTkrWxENn3H8D%2FYY8FS6nSJ9rjBM%2FuyKzHI3JiWhIYUU7Jhv%2BHXNCHyufvnz2OZ7AfRmZSTSqX4vuNHt3jEyHGha669CxzHHNenVS0T79pmWYYiqEpZ0jngN73oqEptcnIn"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Vary: + - Accept-Encoding + alt-svc: + - h3=":443"; ma=86400 + status: + code: 204 + message: No Content +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.30.1 + method: GET + uri: https://api.replicate.com/v1/files + response: + body: + string: !!binary | + H4sIAAAAAAAAA+zbUW+jOhYA4P/Cc5cxxiQhb5lLgohiI+Y6QfZqVaWEBgykaeNOAqP57yvT3ind + RUrcZx5bHaq2X/A5Pj7+ZRzSizSmh9eyvDOOL+nP/On19M/XL+nptZQnY/rvX0a+M6YGhpFFYGRx + EUkWM4BzAEiT5StaOIwWkjS4wTSBuElqTJPauDMO2yo1poZMT/L+MS/T0c/kIOSDY8qLNO6M5Okg + 04O8l/XxLewivx3LbX4w7oxT3qTGFNwZqdzujamxQ9ZukuzcySMADxCg1J0A4LqTNHmcIDhO1Y/L + 0qQ4vVYnY/rLOGVb6IyMqZHaDyBBCLqTx8RKLORuHx8eUTJx3dHjgwsRHG9TZKVohNwH10bJFrmO + 61oP44kDHyaOY9wZ1c655Tf4fWdUqdzutnJrTH/9vjOSl3Qr0939VhpTAwKI/gUm/4KQWnAK4RRB + E7mIG3dGejnmL+npf+Ls/4t7fSnbv22fqsBMyuNp+u3b9pibL+mxzJOtTM3kqfr20/qm/tunb9pk + v3/f/aPtrR1MA4t7gQz9H4LlAHCalas4OLMYS1KtLUyDmlfLkgls9Wg7p8PDQ3mAgzZ0prZj2ha8 + pt2N09fWJfukfcaiLFgzkxgGDf5bPfo9X9F9++HhYlGxCp95HFyIlxU92qfds5A7JAdt9c7aJoDj + G97tP3Ff0dYj62oniIvAxhTL0Css9jcAmC7LFd3b2J9L4rFL6EVQLR+8ikCPNtrLdJs924M2RFNo + maPR1ZW8G/cFbU2yjrbPBfZ2BaFrSShuePvoolrFyyyM55J7swZ7wZlXESLeDPZoj1+P491L9jxo + W+4UQnM8cq9of4rT19Yl62jTAIa0LBhUH5RNjtskkDSreJlzby1xtam4vxE4Xp8x3VU92hcrL+De + /jloQ3uKkOmCa9qf4vS1dck62mJThR6zCYwkgQzxHADs42YVzxvuq+/NIfEYJM2sZiIr+/I2Ohbj + 18nDoK2qL2ROxrdUaX/i9LV1yTrazaYgcF4TlQToJg//AoDRGVrFiwI3cxnGDDCRAE73kDW97/bz + bte8WGd30LYmUzA2xxP72krejdPX1iX70CZwbfN4bYd0JjlNbFwDwGKer+KgYTCQBGInjNcQ+7zk + 8bKvJn88VPVT9jy8220+tk2IwA15+0+ctrY2WUfbK0Do7VVKl5zOIfkLACzm9opmWaiWBboH2I/O + xFsU3A/6anLxZDvp8VwN2tCZAsu07NH1lfwjTl9bl6yjTZOa+3PEKqyKOZuoxoy3VFt1p92+e7ML + VkmAZhmuNlmP9mSfHMdJdRq0IZratmldrck/xelr65J1tTOBvbIt5lizrlUSCD18XsXzM1bb92pe + Ez+CjK4dLpZ9efsFZPk+vwyd06k1amtt+1qV9inuC9qaZB3tmOSERhcsZpLQBKllgTSbvH00jiT3 + Azv0yjz0Awv7pO/dvjwda5nCIW+rdxYAc4wm19/tjzh9bV2yrjYvuBfZPA4kEfuLWhYw3RWreG4R + by9Dj1TYmzlMkBI3GPVon93saNUv6aANwdSxTHd8NW93476grUnW0RZzO/TnNhFryaqN4H8BwD1V + zgeOSgyhtz5zGpyxWFZMrPt6afuDJZ6T9HHQtkZTCEyI3Osr+UecvrYuWVe7gCTGZ9IwyeOlYDUA + 2AvQKp5fsDeTREQNa76Xoc9q7m36uivlxHmePGXDfnsKR1OATHT9VKQb9wVtTbKOdrPJCN2jtuUW + M6dtulbrekULO/TmkjXsjL3C4t4uxzHJe7RrVFTnNBlW8vcVegRuW8nf4/S1dck+tENvWeCYIRbP + JRYMqcNS1pTVKl5UIU0ka7IM+wvBKLNDb930aBevmbVPksugbblTZJnu9V5aN05bW5usq01KHkcW + o6oJ83ZYyiADK7VFbxLJRJkxOG9wFTghXds92vmkspz7p2G/3Z5kTszJ6Iaa/CPuC9qaZF3tssBi + kxE/ktgP6rac93ZiFS8LJpjEIstxs69D73vFqnlfL20ydnaPdTmcb7c7K8uErnXDDuxP3Be0Nck6 + 2hQ7WCwz0obhRqV8Qn+UK1pceLyW2IscDJcZo6zBXtS3kmfHtHi+fz0M2hBM0cgcjezrefsjTl9b + l6yrXVjcj+p2WaCsblvsYmat4rkT+kyyZiOIFyFSbQT2okvf7Mol3+7zzBq0IZjaIxNd75x2476g + rUn2oc3gJmPVomAikTgOoNq8kYqXq3ghSBVIXAVnQplN/HnNBe/rpaHjuBSuOx601Ts7NoHj3PBu + /4nT1tYm62hTNagYXNRoKm4igNtHI3tFS6E2b6xaZBguKuLtShL39tIOzT1wj/broG2NppZrjic3 + dFc+4vS1dck62oJXRCyFasxguK6JGmCMf5Rq7IXFgQzjhWAVBlxkJYl/9J2KPL0WtSvTp0H77STT + ud5d6cbpa+uSdbTVgYnHc9Vix2LWhKqcp1G9ipcVa4/Lv5dcMIepzZwf9HVOJ0lS7Z7vh5pc7aNt + 2wTghinEjzh9bV2yrvZCcBrZDEYypMFFPUrETqzovsZtgZdARvcOVymfzs59M6e7pyy/5NtBGypI + 00Gj6/PkH3Ff0NYk+9DmXgTVssAok6QiQh2oMFG8jbTRmSTNvmGQF5zOEPd6Z1cm9kjsEjl0V9pz + a8dEV088P8Vpa2uTdbV/5FjMYUjVJYNdpZIA9gKoDlRCL5Jt48ZLzrgp1EWEPu0xSg9PwG4GbTVd + ODHh+JYpxD9xX9DWJOto+5uSxEGjBhgZ/VG0jzbssqJZRaAahVnXpNkVocea0Av67oGVP8HlYOXD + NMN7PkYQ3JS33+P0tXXJOto0cIgfAU6xJPGyaEfaaHRetSX+WmJfXS9ZCBbzCvdrny77bfp0nw3a + 6rQDms7VM7BPcfraumRdbWzhKroQ1YSBzHlLAmtbDScTGkhV2hN/U6giL/TXffttd395drPdcDPo + vUc2QaObemkT9FVtTbKu9r7B3qbkFVZXSkR7hQzOwYpmuUoCTBBB/B8ZFgUiovce2L6pxuluez9o + v51tOQDddAb2HvcFbU2yrnZh44ZUYYwl9pZFOwjRrNHb8ZlKAgwRL7BYU1ihV/Te1n+YwPNzUgza + b3c3IXJuuuP5HvcFbU2yjraY16zhJYFzycXcaZOAHyH1QcFVJBkkWRgvstCfX7iHnb6VvCwexZM9 + zKW9z4mPLPumefL3OH1tXbKOdtteX1btpf5m0RZ46uKgarGr6+CkYhah+zP3eBHG6z7tpgAPh8ra + DdpvXRPr6hTipzh9bV2yrjYRLN6U3GeSqNFUNa7a7B31QSEqMfjrmldrO/QjEMZRX02e/XSa40uK + Bu23m7rIdW660fse9wVtTbLfv//z+78AAAD//wMATLi5mI9HAAA= + headers: + CF-RAY: + - 8b72da8a28c1c39f-SEA + Cf-Placement: + - remote-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json; charset=UTF-8 + Date: + - Thu, 22 Aug 2024 12:26:52 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v4?s=pUBDVlUTiRMAnxnBdWk52MfFYaI6Igr0Xii0yU28rQCdZ75ceEZSSGuNqk1yfB4xF1uUv%2BEUcehuEnMoQ9q2d4ZLJRQpJMSKbXRqMbT8VI%2F29sOyJsYRT3%2FQYn5wmXeG8%2FLn"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + Vary: + - Accept-Encoding + alt-svc: + - h3=":443"; ma=86400 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/conftest.py b/tests/conftest.py index 103d1693..a29ed640 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,7 @@ -import asyncio import os from unittest import mock import pytest -import pytest_asyncio - - -@pytest_asyncio.fixture(scope="session", autouse=True) -def event_loop(): - event_loop_policy = asyncio.get_event_loop_policy() - loop = event_loop_policy.new_event_loop() - yield loop - loop.close() @pytest.fixture(scope="session") diff --git a/tests/test_file.py b/tests/test_file.py new file mode 100644 index 00000000..5e80e939 --- /dev/null +++ b/tests/test_file.py @@ -0,0 +1,58 @@ +import tempfile + +import pytest + +import replicate + + +@pytest.mark.vcr("file-operations.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_file_operations(async_flag): + # Create a sample file + with tempfile.NamedTemporaryFile( + mode="wb", delete=False, prefix="test_file", suffix=".txt" + ) as temp_file: + temp_file.write(b"Hello, Replicate!") + + # Test create + if async_flag: + created_file = await replicate.files.async_create(temp_file.name) + else: + created_file = replicate.files.create(temp_file.name) + + assert created_file.name.startswith("test_file") + assert created_file.name.endswith(".txt") + file_id = created_file.id + + # Test get + if async_flag: + retrieved_file = await replicate.files.async_get(file_id) + else: + retrieved_file = replicate.files.get(file_id) + + assert retrieved_file.id == file_id + + # Test list + if async_flag: + file_list = await replicate.files.async_list() + else: + file_list = replicate.files.list() + + assert file_list is not None + assert len(file_list) > 0 + assert any(f.id == file_id for f in file_list) + + # Test delete + if async_flag: + await replicate.files.async_delete(file_id) + else: + replicate.files.delete(file_id) + + # Verify file is deleted + if async_flag: + file_list = await replicate.files.async_list() + else: + file_list = replicate.files.list() + + assert all(f.id != file_id for f in file_list)