Skip to content

Commit

Permalink
feat: process object lifecycle preference from request (#318)
Browse files Browse the repository at this point in the history
* feat: process object lifecycle

* feat: process object lifecycle pref

* properly parse preferences
  • Loading branch information
mederka authored Oct 1, 2024
1 parent e943226 commit 4d4a500
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 11 deletions.
33 changes: 31 additions & 2 deletions projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from zipfile import ZipFile

import pydantic
from fastapi import Request

# https://github.com/pydantic/pydantic/pull/2573
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
Expand Down Expand Up @@ -55,6 +56,7 @@ def get_builtin_repository(id: RepositoryId) -> FileRepository:

DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal_v2"
FALLBACK_REPOSITORY: FileRepository | RepositoryId = "cdn"
OBJECT_LIFECYCLE_PREFERENCE_KEY = "x-fal-object-lifecycle-preference"


class File(BaseModel):
Expand Down Expand Up @@ -132,6 +134,7 @@ def from_bytes(
fallback_repository: Optional[
FileRepository | RepositoryId
] = FALLBACK_REPOSITORY,
request: Optional[Request] = None,
) -> File:
repo = (
repository
Expand All @@ -141,8 +144,10 @@ def from_bytes(

fdata = FileData(data, content_type, file_name)

object_lifecycle_preference = _get_lifecycle_preference(request)

try:
url = repo.save(fdata)
url = repo.save(fdata, object_lifecycle_preference)
except Exception:
if not fallback_repository:
raise
Expand All @@ -153,7 +158,7 @@ def from_bytes(
else get_builtin_repository(fallback_repository)
)

url = fallback_repo.save(fdata)
url = fallback_repo.save(fdata, object_lifecycle_preference)

return cls(
url=url,
Expand All @@ -173,6 +178,7 @@ def from_path(
fallback_repository: Optional[
FileRepository | RepositoryId
] = FALLBACK_REPOSITORY,
request: Optional[Request] = None,
) -> File:
file_path = Path(path)
if not file_path.exists():
Expand All @@ -185,12 +191,14 @@ def from_path(
)

content_type = content_type or "application/octet-stream"
object_lifecycle_preference = _get_lifecycle_preference(request)

try:
url, data = repo.save_file(
file_path,
content_type=content_type,
multipart=multipart,
object_lifecycle_preference=object_lifecycle_preference,
)
except Exception:
if not fallback_repository:
Expand All @@ -206,6 +214,7 @@ def from_path(
file_path,
content_type=content_type,
multipart=multipart,
object_lifecycle_preference=object_lifecycle_preference,
)

return cls(
Expand Down Expand Up @@ -263,3 +272,23 @@ def glob(self, pattern: str):
def __del__(self):
if self.extract_dir:
shutil.rmtree(self.extract_dir)


def _get_lifecycle_preference(request: Request) -> dict[str, str] | None:
import json

preference_str = (
request.headers.get(OBJECT_LIFECYCLE_PREFERENCE_KEY)
if request is not None
else None
)
if preference_str is None:
return None

object_lifecycle_preference = {}
try:
object_lifecycle_preference = json.loads(preference_str)
return object_lifecycle_preference
except Exception as e:
print(f"Failed to parse object lifecycle preference: {e}")
return None
28 changes: 22 additions & 6 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def _upload_file(self, upload_url: str, file: FileData):

@dataclass
class FalFileRepository(FalFileRepositoryBase):
def save(self, file: FileData) -> str:
def save(
self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
) -> str:
return self._save(file, "gcs")


Expand Down Expand Up @@ -276,7 +278,9 @@ def complete(self):
@dataclass
class FalFileRepositoryV2(FalFileRepositoryBase):
@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
def save(self, file: FileData) -> str:
def save(
self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
) -> str:
token = fal_v2_token_manager.get_token()
headers = {
"Authorization": f"{token.token_type} {token.token}",
Expand Down Expand Up @@ -328,6 +332,7 @@ def save_file(
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: dict[str, str] | None = None,
) -> tuple[str, FileData | None]:
if multipart is None:
threshold = multipart_threshold or MultipartUpload.MULTIPART_THRESHOLD
Expand All @@ -348,7 +353,7 @@ def save_file(
content_type=content_type,
file_name=os.path.basename(file_path),
)
url = self.save(data)
url = self.save(data, object_lifecycle_preference)

return url, data

Expand All @@ -358,6 +363,7 @@ class InMemoryRepository(FileRepository):
def save(
self,
file: FileData,
object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'

Expand All @@ -368,6 +374,7 @@ class FalCDNFileRepository(FileRepository):
def save(
self,
file: FileData,
object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
headers = {
**self.auth_headers,
Expand Down Expand Up @@ -408,16 +415,25 @@ def auth_headers(self) -> dict[str, str]:
class FalFileRepositoryV3(FileRepository):
@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
def save(
self,
file: FileData,
self, file: FileData, user_lifecycle_preference: dict[str, str] | None
) -> str:
object_lifecycle_preference = dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)

if user_lifecycle_preference is not None:
object_lifecycle_preference = {
key: user_lifecycle_preference[key]
if key in user_lifecycle_preference
else value
for key, value in object_lifecycle_preference.items()
}

headers = {
**self.auth_headers,
"Accept": "application/json",
"Content-Type": file.content_type,
"X-Fal-File-Name": file.file_name,
"X-Fal-Object-Lifecycle-Preference": json.dumps(
dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)
object_lifecycle_preference
),
}
url = os.getenv("FAL_CDN_V3_HOST", _FAL_CDN_V3) + "/files/upload"
Expand Down
11 changes: 8 additions & 3 deletions projects/fal/src/fal/toolkit/file/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from mimetypes import guess_extension, guess_type
from pathlib import Path
from typing import Literal
from typing import Literal, Optional
from uuid import uuid4


Expand Down Expand Up @@ -36,7 +36,11 @@ def __init__(

@dataclass
class FileRepository:
def save(self, data: FileData) -> str:
def save(
self,
data: FileData,
object_lifecycle_preference: Optional[dict[str, str]] = None,
) -> str:
raise NotImplementedError()

def save_file(
Expand All @@ -47,11 +51,12 @@ def save_file(
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: Optional[dict[str, str]] = None,
) -> tuple[str, FileData | None]:
if multipart:
raise NotImplementedError()

with open(file_path, "rb") as fobj:
data = FileData(fobj.read(), content_type, Path(file_path).name)

return self.save(data), data
return self.save(data, object_lifecycle_preference), data
5 changes: 5 additions & 0 deletions projects/fal/src/fal/toolkit/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Literal, Optional, Union

from fastapi import Request
from pydantic import BaseModel, Field

from fal.toolkit.file.file import DEFAULT_REPOSITORY, FALLBACK_REPOSITORY, File
Expand Down Expand Up @@ -82,13 +83,15 @@ def from_bytes( # type: ignore[override]
fallback_repository: Optional[
FileRepository | RepositoryId
] = FALLBACK_REPOSITORY,
request: Optional[Request] = None,
) -> Image:
obj = super().from_bytes(
data,
content_type=f"image/{format}",
file_name=file_name,
repository=repository,
fallback_repository=fallback_repository,
request=request,
)
obj.width = size.width if size else None
obj.height = size.height if size else None
Expand All @@ -104,6 +107,7 @@ def from_pil(
fallback_repository: Optional[
FileRepository | RepositoryId
] = FALLBACK_REPOSITORY,
request: Optional[Request] = None,
) -> Image:
size = ImageSize(width=pil_image.width, height=pil_image.height)
if format is None:
Expand Down Expand Up @@ -133,6 +137,7 @@ def from_pil(
file_name,
repository,
fallback_repository=fallback_repository,
request=request,
)

def to_pil(self, mode: str = "RGB") -> PILImage.Image:
Expand Down

0 comments on commit 4d4a500

Please sign in to comment.