From 058fa2c8e52adb046f3a741ff22b119ab3688d3f Mon Sep 17 00:00:00 2001 From: Andrew Smith Date: Wed, 25 Dec 2024 16:35:35 +0000 Subject: [PATCH] feat: add custom-metadata to upload --- storage3/_async/file_api.py | 24 +++++++++++++++++++----- storage3/_sync/file_api.py | 24 +++++++++++++++++++----- storage3/types.py | 11 +++++++++-- tests/_async/conftest.py | 12 +++++++++--- tests/_async/test_client.py | 27 +++++++++++++++++++++++++-- tests/_sync/conftest.py | 12 +++++++++--- tests/_sync/test_client.py | 23 +++++++++++++++++++++++ 7 files changed, 113 insertions(+), 20 deletions(-) diff --git a/storage3/_async/file_api.py b/storage3/_async/file_api.py index 8529eca6..fe4a1d3e 100644 --- a/storage3/_async/file_api.py +++ b/storage3/_async/file_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import json import urllib.parse from dataclasses import dataclass, field from io import BufferedReader, FileIO @@ -435,11 +437,15 @@ async def _upload_or_update( """ if file_options is None: file_options = {} - cache_control = file_options.get("cache-control") + cache_control = file_options.pop("cache-control", None) _data = {} - if file_options.get("upsert"): - file_options.update({"x-upsert": file_options.get("upsert")}) - del file_options["upsert"] + + upsert = file_options.pop("upsert", None) + if upsert: + file_options.update({"x-upsert": upsert}) + + metadata = file_options.pop("metadata", None) + file_opts_headers = file_options.pop("headers", None) headers = { **self._client.headers, @@ -447,6 +453,14 @@ async def _upload_or_update( **file_options, } + if metadata: + metadata_str = json.dumps(metadata) + headers["x-metadata"] = base64.b64encode(metadata_str.encode()) + _data.update({"metadata": metadata_str}) + + if file_opts_headers: + headers.update({**file_opts_headers}) + # Only include x-upsert on a POST method if method != "POST": del headers["x-upsert"] @@ -455,7 +469,7 @@ async def _upload_or_update( if cache_control: headers["cache-control"] = f"max-age={cache_control}" - _data = {"cacheControl": cache_control} + _data.update({"cacheControl": cache_control}) if ( isinstance(file, BufferedReader) diff --git a/storage3/_sync/file_api.py b/storage3/_sync/file_api.py index 140c89a7..0aca339f 100644 --- a/storage3/_sync/file_api.py +++ b/storage3/_sync/file_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import json import urllib.parse from dataclasses import dataclass, field from io import BufferedReader, FileIO @@ -435,11 +437,15 @@ def _upload_or_update( """ if file_options is None: file_options = {} - cache_control = file_options.get("cache-control") + cache_control = file_options.pop("cache-control", None) _data = {} - if file_options.get("upsert"): - file_options.update({"x-upsert": file_options.get("upsert")}) - del file_options["upsert"] + + upsert = file_options.pop("upsert", None) + if upsert: + file_options.update({"x-upsert": upsert}) + + metadata = file_options.pop("metadata", None) + file_opts_headers = file_options.pop("headers", None) headers = { **self._client.headers, @@ -447,6 +453,14 @@ def _upload_or_update( **file_options, } + if metadata: + metadata_str = json.dumps(metadata) + headers["x-metadata"] = base64.b64encode(metadata_str.encode()) + _data.update({"metadata": metadata_str}) + + if file_opts_headers: + headers.update({**file_opts_headers}) + # Only include x-upsert on a POST method if method != "POST": del headers["x-upsert"] @@ -455,7 +469,7 @@ def _upload_or_update( if cache_control: headers["cache-control"] = f"max-age={cache_control}" - _data = {"cacheControl": cache_control} + _data.update({"cacheControl": cache_control}) if ( isinstance(file, BufferedReader) diff --git a/storage3/types.py b/storage3/types.py index 5e86d3c6..ef4080c3 100644 --- a/storage3/types.py +++ b/storage3/types.py @@ -2,7 +2,7 @@ from dataclasses import asdict, dataclass from datetime import datetime -from typing import Literal, Optional, TypedDict, Union +from typing import Any, Dict, Literal, Optional, TypedDict, Union import dateutil.parser @@ -77,7 +77,14 @@ class DownloadOptions(TypedDict, total=False): FileOptions = TypedDict( "FileOptions", - {"cache-control": str, "content-type": str, "x-upsert": str, "upsert": str}, + { + "cache-control": str, + "content-type": str, + "x-upsert": str, + "upsert": str, + "metadata": Dict[str, Any], + "headers": Dict[str, str], + }, total=False, ) diff --git a/tests/_async/conftest.py b/tests/_async/conftest.py index 11069dbe..ec08ac3d 100644 --- a/tests/_async/conftest.py +++ b/tests/_async/conftest.py @@ -2,6 +2,7 @@ import asyncio import os +from collections.abc import AsyncGenerator, Generator import pytest from dotenv import load_dotenv @@ -14,13 +15,18 @@ def pytest_configure(config) -> None: @pytest.fixture(scope="package") -def event_loop() -> asyncio.AbstractEventLoop: +def event_loop() -> Generator[asyncio.AbstractEventLoop]: """Returns an event loop for the current thread""" - return asyncio.get_event_loop_policy().get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + yield loop + loop.close() @pytest.fixture(scope="package") -async def storage() -> AsyncStorageClient: +async def storage() -> AsyncGenerator[AsyncStorageClient]: url = os.environ.get("SUPABASE_TEST_URL") assert url is not None, "Must provide SUPABASE_TEST_URL environment variable" key = os.environ.get("SUPABASE_TEST_KEY") diff --git a/tests/_async/test_client.py b/tests/_async/test_client.py index 38faf339..e156796b 100644 --- a/tests/_async/test_client.py +++ b/tests/_async/test_client.py @@ -398,8 +398,31 @@ async def test_client_get_public_url( assert response.content == file.file_content +async def test_client_upload_with_custom_metadata( + storage_file_client_public: AsyncBucketProxy, file: FileForTesting +) -> None: + """Ensure we can get the public url of a file in a bucket""" + await storage_file_client_public.upload( + file.bucket_path, + file.local_path, + { + "content-type": file.mime_type, + "metadata": {"custom": "metadata", "second": "second", "third": "third"}, + }, + ) + + info = await storage_file_client_public.info(file.bucket_path) + assert "metadata" in info.keys() + assert info["name"] == file.bucket_path + assert info["metadata"] == { + "custom": "metadata", + "second": "second", + "third": "third", + } + + async def test_client_info( - storage_file_client_public: SyncBucketProxy, file: FileForTesting + storage_file_client_public: AsyncBucketProxy, file: FileForTesting ) -> None: """Ensure we can get the public url of a file in a bucket""" await storage_file_client_public.upload( @@ -413,7 +436,7 @@ async def test_client_info( async def test_client_exists( - storage_file_client_public: SyncBucketProxy, file: FileForTesting + storage_file_client_public: AsyncBucketProxy, file: FileForTesting ) -> None: """Ensure we can get the public url of a file in a bucket""" await storage_file_client_public.upload( diff --git a/tests/_sync/conftest.py b/tests/_sync/conftest.py index 2c21ca16..fe5064a7 100644 --- a/tests/_sync/conftest.py +++ b/tests/_sync/conftest.py @@ -2,6 +2,7 @@ import asyncio import os +from collections.abc import Generator import pytest from dotenv import load_dotenv @@ -14,13 +15,18 @@ def pytest_configure(config) -> None: @pytest.fixture(scope="package") -def event_loop() -> asyncio.AbstractEventLoop: +def event_loop() -> Generator[asyncio.AbstractEventLoop]: """Returns an event loop for the current thread""" - return asyncio.get_event_loop_policy().get_event_loop() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + yield loop + loop.close() @pytest.fixture(scope="package") -def storage() -> SyncStorageClient: +def storage() -> Generator[SyncStorageClient]: url = os.environ.get("SUPABASE_TEST_URL") assert url is not None, "Must provide SUPABASE_TEST_URL environment variable" key = os.environ.get("SUPABASE_TEST_KEY") diff --git a/tests/_sync/test_client.py b/tests/_sync/test_client.py index 06274099..83ffb4b1 100644 --- a/tests/_sync/test_client.py +++ b/tests/_sync/test_client.py @@ -396,6 +396,29 @@ def test_client_get_public_url( assert response.content == file.file_content +def test_client_upload_with_custom_metadata( + storage_file_client_public: SyncBucketProxy, file: FileForTesting +) -> None: + """Ensure we can get the public url of a file in a bucket""" + storage_file_client_public.upload( + file.bucket_path, + file.local_path, + { + "content-type": file.mime_type, + "metadata": {"custom": "metadata", "second": "second", "third": "third"}, + }, + ) + + info = storage_file_client_public.info(file.bucket_path) + assert "metadata" in info.keys() + assert info["name"] == file.bucket_path + assert info["metadata"] == { + "custom": "metadata", + "second": "second", + "third": "third", + } + + def test_client_info( storage_file_client_public: SyncBucketProxy, file: FileForTesting ) -> None: