Skip to content

Commit

Permalink
feat(toolkit): support multipart in v3 cdn (#325)
Browse files Browse the repository at this point in the history
feat(toolkit): multipart for cdn v3
  • Loading branch information
efiop authored Oct 8, 2024
1 parent 5e42ae4 commit 9aa6fc3
Showing 1 changed file with 192 additions and 5 deletions.
197 changes: 192 additions & 5 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at


class FalV3Token(FalV2Token):
    """Token for the v3 CDN storage backend; same fields as the v2 token."""


class FalV2TokenManager:
token_cls: type[FalV2Token] = FalV2Token
storage_type: str = "fal-cdn"
upload_prefix = "upload."

def __init__(self):
self._token: FalV2Token = FalV2Token(
self._token: FalV2Token = self.token_cls(
token="",
token_type="",
base_upload_url="",
Expand Down Expand Up @@ -63,7 +71,7 @@ def _refresh_token(self) -> None:

grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
rest_host = grpc_host.replace("api", "rest", 1)
url = f"https://{rest_host}/storage/auth/token"
url = f"https://{rest_host}/storage/auth/token?storage_type={self.storage_type}"

req = Request(
url,
Expand All @@ -76,18 +84,25 @@ def _refresh_token(self) -> None:

parsed_base_url = urlparse(result["base_url"])
base_upload_url = urlunparse(
parsed_base_url._replace(netloc="upload." + parsed_base_url.netloc)
parsed_base_url._replace(netloc=self.upload_prefix + parsed_base_url.netloc)
)

self._token = FalV2Token(
self._token = self.token_cls(
token=result["token"],
token_type=result["token_type"],
base_upload_url=base_upload_url,
expires_at=datetime.fromisoformat(result["expires_at"]),
)


class FalV3TokenManager(FalV2TokenManager):
    """Token manager targeting the "fal-cdn-v3" storage backend.

    Reuses the v2 refresh flow unchanged; the v3 upload host is used as-is,
    so no "upload." prefix is prepended to the base URL.
    """

    storage_type: str = "fal-cdn-v3"
    upload_prefix = ""
    token_cls: type[FalV2Token] = FalV3Token


# Module-level singletons: one token manager per storage backend so all
# uploads in the process share (and refresh) a single cached token.
fal_v2_token_manager = FalV2TokenManager()
fal_v3_token_manager = FalV3TokenManager()


@dataclass
Expand Down Expand Up @@ -275,6 +290,128 @@ def complete(self):
return self._file_url


class MultipartUploadV3:
    """Multipart upload of a local file to the fal v3 CDN.

    Usage: ``create()`` -> ``upload()`` -> ``complete()``. The file is split
    into fixed-size chunks that are PUT concurrently, then stitched together
    by the completion endpoint.
    """

    # Callers switch to multipart above this file size (bytes).
    MULTIPART_THRESHOLD = 100 * 1024 * 1024
    # Size of each uploaded part (bytes).
    MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
    # Upper bound on concurrently uploading parts.
    MULTIPART_MAX_CONCURRENCY = 10

    def __init__(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> None:
        self.file_path = file_path
        self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
        self.content_type = content_type or "application/octet-stream"
        self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
        # Populated by create().
        self.access_url: str | None = None
        self.upload_id: str | None = None

        # Completed part descriptors ({"partNumber", "etag"}); appended from
        # worker threads in upload() (list.append is atomic in CPython).
        self._parts: list[dict] = []

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization headers built from a fresh (auto-refreshed) v3 token."""
        token = fal_v3_token_manager.get_token()
        return {
            "Authorization": f"{token.token_type} {token.token}",
            "User-Agent": "fal/0.1.0",
        }

    def create(self):
        """Initiate the multipart upload and record access_url / upload_id.

        Raises:
            FileUploadException: if the initiation request fails.
        """
        token = fal_v3_token_manager.get_token()
        try:
            req = Request(
                f"{token.base_upload_url}/files/upload/multipart",
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                    "Content-Type": self.content_type,
                    "X-Fal-File-Name": os.path.basename(self.file_path),
                },
            )
            with urlopen(req) as response:
                result = json.load(response)
                self.access_url = result["access_url"]
                self.upload_id = result["uploadId"]
        except HTTPError as exc:
            # Chain the cause so the underlying HTTP failure is not lost.
            raise FileUploadException(
                f"Error initiating upload. Status {exc.status}: {exc.reason}"
            ) from exc

    @retry(max_retries=5, base_delay=1, backoff_type="exponential", jitter=True)
    def _upload_part(self, url: str, part_number: int) -> dict:
        """Upload one 1-based part and return its {"partNumber", "etag"} entry.

        Each call opens its own file handle so concurrent workers don't race
        on a shared seek position; retries re-read the chunk from disk.
        """
        with open(self.file_path, "rb") as f:
            start = (part_number - 1) * self.chunk_size
            f.seek(start)
            data = f.read(self.chunk_size)
            req = Request(
                url,
                method="PUT",
                headers={
                    **self.auth_headers,
                    "Content-Type": self.content_type,
                },
                data=data,
            )

            try:
                with urlopen(req) as resp:
                    return {
                        "partNumber": part_number,
                        "etag": resp.headers["ETag"],
                    }
            except HTTPError as exc:
                raise FileUploadException(
                    f"Error uploading part {part_number} to {url}. "
                    f"Status {exc.status}: {exc.reason}"
                ) from exc

    def upload(self) -> None:
        """Upload all parts concurrently; must be called after create()."""
        import concurrent.futures

        parts = math.ceil(os.path.getsize(self.file_path) / self.chunk_size)
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_concurrency
        ) as executor:
            futures = [
                executor.submit(
                    self._upload_part,
                    f"{self.access_url}/multipart/{self.upload_id}/{part_number}",
                    part_number,
                )
                for part_number in range(1, parts + 1)
            ]

            for future in concurrent.futures.as_completed(futures):
                # result() re-raises any part-upload failure here.
                self._parts.append(future.result())

    def complete(self):
        """Finalize the upload and return the file's access URL.

        Raises:
            FileUploadException: if the completion request fails.
        """
        url = f"{self.access_url}/multipart/{self.upload_id}/complete"
        # as_completed() yields results in arbitrary order; multipart
        # completion endpoints conventionally require parts sorted by
        # ascending part number, so sort before sending.
        parts = sorted(self._parts, key=lambda part: part["partNumber"])
        try:
            req = Request(
                url,
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                data=json.dumps({"parts": parts}).encode(),
            )
            with urlopen(req):
                pass
        except HTTPError as e:
            raise FileUploadException(
                f"Error completing upload {url}. Status {e.status}: {e.reason}"
            ) from e

        return self.access_url


@dataclass
class FalFileRepositoryV2(FalFileRepositoryBase):
@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
Expand Down Expand Up @@ -451,8 +588,58 @@ def save(

@property
def auth_headers(self) -> dict[str, str]:
token = fal_v2_token_manager.get_token()
token = fal_v3_token_manager.get_token()
return {
"Authorization": f"{token.token_type} {token.token}",
"User-Agent": "fal/0.1.0",
}

def _save_multipart(
self,
file_path: str | Path,
chunk_size: int | None = None,
content_type: str | None = None,
max_concurrency: int | None = None,
) -> str:
multipart = MultipartUploadV3(
file_path,
chunk_size=chunk_size,
content_type=content_type,
max_concurrency=max_concurrency,
)
multipart.create()
multipart.upload()
return multipart.complete()

def save_file(
self,
file_path: str | Path,
content_type: str,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: dict[str, str] | None = None,
) -> tuple[str, FileData | None]:
if multipart is None:
threshold = multipart_threshold or MultipartUpload.MULTIPART_THRESHOLD
multipart = os.path.getsize(file_path) > threshold

if multipart:
url = self._save_multipart(
file_path,
chunk_size=multipart_chunk_size,
content_type=content_type,
max_concurrency=multipart_max_concurrency,
)
data = None
else:
with open(file_path, "rb") as f:
data = FileData(
f.read(),
content_type=content_type,
file_name=os.path.basename(file_path),
)
url = self.save(data, object_lifecycle_preference)

return url, data

0 comments on commit 9aa6fc3

Please sign in to comment.