Skip to content

Commit

Permalink
refactored tarball creation/extraction to use `create_tarball`/`extract_tarball`
Browse files Browse the repository at this point in the history
  • Loading branch information
telamonian committed Aug 30, 2024
1 parent b6097f6 commit 0cfb7df
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 63 deletions.
68 changes: 11 additions & 57 deletions comfy_cli/standalone.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import shutil
import subprocess
import tarfile
from pathlib import Path
from typing import Optional

import requests
from rich.live import Live
from rich.progress import Progress, TextColumn
from rich.table import Table

from comfy_cli.constants import OS, PROC
from comfy_cli.typing import PathLike
from comfy_cli.utils import download_progress, get_os, get_proc
from comfy_cli.utils import create_tarball, download_url, extract_tarball, get_os, get_proc
from comfy_cli.uv import DependencyCompiler

_here = Path(__file__).expanduser().resolve().parent
Expand All @@ -36,6 +32,7 @@ def download_standalone_python(
tag: str = "latest",
flavor: str = "install_only",
cwd: PathLike = ".",
show_progress: bool = True,
) -> PathLike:
"""grab a pre-built distro from the python-build-standalone project. See
https://gregoryszorc.com/docs/python-build-standalone/main/"""
Expand All @@ -60,7 +57,7 @@ def download_standalone_python(
fname = f"{name}.tar.gz"
url = f"{asset_url_prefix}/{fname}"

return download_progress(url, fname, cwd=cwd)
return download_url(url, fname, cwd=cwd, show_progress=show_progress)


class StandalonePython:
Expand All @@ -73,36 +70,25 @@ def FromDistro(
flavor: str = "install_only",
cwd: PathLike = ".",
name: PathLike = "python",
):
show_progress: bool = True,
) -> "StandalonePython":
fpath = download_standalone_python(
platform=platform,
proc=proc,
version=version,
tag=tag,
flavor=flavor,
cwd=cwd,
show_progress=show_progress,
)
return StandalonePython.FromTarball(fpath, name)

@staticmethod
def FromTarball(fpath: PathLike, name: PathLike = "python") -> "StandalonePython":
def FromTarball(fpath: PathLike, name: PathLike = "python", show_progress: bool = True) -> "StandalonePython":
fpath = Path(fpath)

with tarfile.open(fpath) as tar:
info = tar.next()
old_name = info.name.split("/")[0]

old_rpath = fpath.parent / old_name
rpath = fpath.parent / name

# clean the tar file expand target and the final target
shutil.rmtree(old_rpath, ignore_errors=True)
shutil.rmtree(rpath, ignore_errors=True)

with tarfile.open(fpath) as tar:
tar.extractall()

shutil.move(old_rpath, rpath)
extract_tarball(inPath=fpath, outPath=rpath, show_progress=show_progress)

return StandalonePython(rpath=rpath)

Expand Down Expand Up @@ -177,40 +163,8 @@ def rehydrate_comfy_deps(self):
)
self.dep_comp.install_wheels_directly()

def to_tarball(self, outPath: Optional[PathLike] = None, progress: bool = True):
outPath = self.rpath.with_suffix(".tgz") if outPath is None else Path(outPath)

# do a little clean up prep
outPath.unlink(missing_ok=True)
def to_tarball(self, outPath: Optional[PathLike] = None, show_progress: bool = True):
# remove any __pycache__ before creating archive
self.clean()

if progress:
fileSize = sum(f.stat().st_size for f in self.rpath.glob("**/*"))

barProg = Progress()
addTar = barProg.add_task("[cyan]Creating tarball...", total=fileSize)
pathProg = Progress(TextColumn("{task.description}"))
pathTar = pathProg.add_task("")

progress_table = Table.grid()
progress_table.add_row(barProg)
progress_table.add_row(pathProg)

_size = 0

def _filter(tinfo: tarfile.TarInfo):
nonlocal _size
pathProg.update(pathTar, description=tinfo.path)
barProg.advance(addTar, _size)
_size = Path(tinfo.path).stat().st_size
return tinfo
else:
_filter = None

with Live(progress_table, refresh_per_second=10):
with tarfile.open(outPath, "w:gz") as tar:
tar.add(self.rpath.relative_to(Path(".").expanduser().resolve()), filter=_filter)

if progress:
barProg.advance(addTar, _size)
pathProg.update(pathTar, description="")
create_tarball(inPath=self.rpath, outPath=outPath, show_progress=show_progress)
124 changes: 118 additions & 6 deletions comfy_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
import platform
import shutil
import subprocess
import tarfile
from pathlib import Path
from typing import Optional

import psutil
import requests
import typer
from rich import print, progress
from rich.live import Live
from rich.table import Table

from comfy_cli.constants import DEFAULT_COMFY_WORKSPACE, OS, PROC
from comfy_cli.typing import PathLike
Expand Down Expand Up @@ -100,7 +104,13 @@ def f(incomplete: str) -> list[str]:
return f


def download_progress(url: str, fname: PathLike, cwd: PathLike = ".", allow_redirects: bool = True) -> PathLike:
def download_url(
url: str,
fname: PathLike,
cwd: PathLike = ".",
allow_redirects: bool = True,
show_progress: bool = True,
) -> PathLike:
"""download url to local file fname and show a progress bar.
See https://stackoverflow.com/q/37573483"""
cwd = Path(cwd).expanduser().resolve()
Expand All @@ -110,12 +120,114 @@ def download_progress(url: str, fname: PathLike, cwd: PathLike = ".", allow_redi
if response.status_code != 200:
response.raise_for_status() # Will only raise for 4xx codes, so...
raise RuntimeError(f"Request to {url} returned status code {response.status_code}")
fsize = int(response.headers.get("Content-Length", 0))

desc = "(Unknown total file size)" if fsize == 0 else ""
response.raw.read = functools.partial(response.raw.read, decode_content=True) # Decompress if needed
with progress.wrap_file(response.raw, total=fsize, description=desc) as response_raw:
with fpath.open("wb") as f:
shutil.copyfileobj(response_raw, f)
with fpath.open("wb") as f:
if show_progress:
fsize = int(response.headers.get("Content-Length", 0))
desc = f"downloading {fname}..." + ("(Unknown total file size)" if fsize == 0 else "")

with progress.wrap_file(response.raw, total=fsize, description=desc) as response_raw:
shutil.copyfileobj(response_raw, f)
else:
shutil.copyfileobj(response.raw, f)

return fpath


def extract_tarball(
inPath: PathLike,
outPath: Optional[PathLike] = None,
show_progress: bool = True,
):
inPath = Path(inPath).expanduser().resolve()
outPath = inPath.with_suffix("") if outPath is None else Path(outPath).expanduser().resolve()

with tarfile.open(inPath) as tar:
info = tar.next()
old_name = info.name.split("/")[0]
# path to top-level of extraction result
extractPath = inPath.with_name(old_name)

# clean both the extraction path and the final target path
shutil.rmtree(extractPath, ignore_errors=True)
shutil.rmtree(outPath, ignore_errors=True)

if show_progress:
fileSize = inPath.stat().st_size

barProg = progress.Progress()
barTask = barProg.add_task("[cyan]extracting tarball...", total=fileSize)
pathProg = progress.Progress(progress.TextColumn("{task.description}"))
pathTask = pathProg.add_task("")

progress_table = Table.grid()
progress_table.add_row(barProg)
progress_table.add_row(pathProg)

_size = 0

def _filter(tinfo: tarfile.TarInfo, _path: Optional[PathLike] = None):
nonlocal _size
pathProg.update(pathTask, description=tinfo.path)
barProg.advance(barTask, _size)
_size = tinfo.size
return tinfo
else:
_filter = None

with Live(progress_table, refresh_per_second=10):
with tarfile.open(inPath) as tar:
tar.extractall(filter=_filter)

Check failure

Code scanning / CodeQL

Arbitrary file write during tarfile extraction High

This file extraction depends on a
potentially untrusted source
.

if show_progress:
barProg.advance(barTask, _size)
pathProg.update(pathTask, description="")

shutil.move(extractPath, outPath)


def create_tarball(
    inPath: PathLike,
    outPath: Optional[PathLike] = None,
    cwd: Optional[PathLike] = None,
    show_progress: bool = True,
):
    """Create a gzip-compressed tarball of inPath.

    Member names inside the archive are stored relative to ``cwd`` so that
    parent directories of ``inPath`` are not included.

    Args:
        inPath: file or directory to archive.
        outPath: path of the archive to write; any existing file there is
            replaced. Defaults to ``inPath`` with its suffix set to ``.tgz``.
        cwd: directory that archive member names are made relative to.
            Defaults to the current working directory.
        show_progress: when true, render a progress bar and the path of the
            member currently being added.

    Raises:
        ValueError: if ``inPath`` is not located under ``cwd``.
    """
    cwd = Path("." if cwd is None else cwd).expanduser().resolve()
    inPath = Path(inPath).expanduser().resolve()
    outPath = inPath.with_suffix(".tgz") if outPath is None else Path(outPath).expanduser().resolve()

    # don't include parent paths in archive; computed before anything is
    # deleted so that a bad inPath/cwd pair leaves an existing archive intact
    arcname = inPath.relative_to(cwd).as_posix()

    # clean the archive target path
    outPath.unlink(missing_ok=True)

    if not show_progress:
        with tarfile.open(outPath, "w:gz") as tar:
            tar.add(inPath, arcname=arcname)
        return

    # only regular files carry payload bytes in the archive; skipping symlinks
    # also avoids stat() failures on dangling links
    candidates = [inPath] if inPath.is_file() else inPath.rglob("*")
    fileSize = sum(f.stat().st_size for f in candidates if not f.is_symlink() and f.is_file())

    barProg = progress.Progress()
    barTask = barProg.add_task("[cyan]creating tarball...", total=fileSize)
    pathProg = progress.Progress(progress.TextColumn("{task.description}"))
    pathTask = pathProg.add_task("")

    progress_table = Table.grid()
    progress_table.add_row(barProg)
    progress_table.add_row(pathProg)

    _size = 0

    def _filter(tinfo: tarfile.TarInfo):
        # the filter runs before a member is written, so credit the bar with
        # the size of the previous (now finished) member
        nonlocal _size
        pathProg.update(pathTask, description=tinfo.path)
        barProg.advance(barTask, _size)
        _size = tinfo.size
        return tinfo

    with Live(progress_table, refresh_per_second=10):
        with tarfile.open(outPath, "w:gz") as tar:
            tar.add(inPath, arcname=arcname, filter=_filter)

        barProg.update(barTask, completed=fileSize)
        pathProg.update(pathTask, description="")

0 comments on commit 0cfb7df

Please sign in to comment.