Skip to content

Commit

Permalink
File downloader now supports simple caching.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed Sep 5, 2024
1 parent 10e0b20 commit 0c96ffa
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 40 deletions.
111 changes: 83 additions & 28 deletions src/wsi_data/io/_download.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import fsspec
from rich.progress import (
Progress,
Expand All @@ -9,35 +11,88 @@
)


class Downloader:
def __init__(self, url, dest):
class CacheDownloader:
"""A downloader with simple cache mechanism.
This class can download from arbitrary URLs and cache the downloaded file.
The hash of the downloaded file is stored in a hidden file
in the same directory of the downloaded file.
Parameters
----------
url : str
URL to download.
name : str, optional
Name of the file, by default None.
cache_dir : str, optional
Cache directory, by default None.
"""

def __init__(self, url, name=None, cache_dir=None):
self.url = url
self.dest = dest
with fsspec.open(self.url, "rb") as fsrc:
self.total_size = fsrc.size # Retrieve the total file size
if name is None:
if hasattr(fsrc, "name"):
name = fsrc.name
if name is None:
name = Path(url).name
if cache_dir is None:
cache_dir = "." # Default to current directory
cache_dir = Path(cache_dir)
self.name = name
self.dest = cache_dir / name
self.hash_path = cache_dir / f".{name}.hash"

def is_cache(self):
if self.dest.exists() and self.hash_path.exists():
with open(self.hash_path, "r") as f:
last_file_hash = f.read()
with open(self.dest, "rb") as f:
current_file_hash = self._hash_file(f)
return last_file_hash == current_file_hash
return False

@staticmethod
def _hash_file(fileobj):
import hashlib

digest = hashlib.file_digest(fileobj, "sha256")
return digest.hexdigest()

def download(self, pbar=True):
"""Download a single file with progress tracking."""
progress = Progress(
TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
BarColumn(bar_width=20),
"[progress.percentage]{task.percentage:>3.1f}%",
"•",
DownloadColumn(),
"•",
TransferSpeedColumn(),
"•",
TimeRemainingColumn(),
disable=not pbar,
)
with progress:
with fsspec.open(self.url, "rb") as fsrc:
total_size = fsrc.size # Retrieve the total file size
task_id = progress.add_task(
"Downloading test", filename=self.dest, total=total_size
)
with fsspec.open(f"{self.dest}", "wb") as fdst:
progress.start_task(task_id)
chunk_size = 1024 * 1024 # 1 MB
while chunk := fsrc.read(chunk_size):
fdst.write(chunk)
progress.advance(task_id, chunk_size)
progress.refresh()
if self.is_cache():
return self.dest
else:
progress = Progress(
TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
BarColumn(bar_width=20),
"[progress.percentage]{task.percentage:>3.1f}%",
"•",
DownloadColumn(),
"•",
TransferSpeedColumn(),
"•",
TimeRemainingColumn(),
disable=not pbar,
)
with progress:
with fsspec.open(self.url, "rb") as fsrc:
task_id = progress.add_task(
"Downloading test", filename=self.dest, total=self.total_size
)
with fsspec.open(f"{self.dest}", "wb") as fdst:
progress.start_task(task_id)
chunk_size = 1024 * 1024 # 1 MB
while chunk := fsrc.read(chunk_size):
fdst.write(chunk)
progress.advance(task_id, chunk_size)
progress.refresh()
# Create a hash file
with open(self.hash_path, "w") as f:
with open(self.dest, "rb") as fdst:
f.write(self._hash_file(fdst))
return self.dest
17 changes: 5 additions & 12 deletions src/wsi_data/io/_open_wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
from wsi_data.data import WSIData
from wsi_data.reader import get_reader

from ._download import Downloader
from ._download import CacheDownloader


def open_wsi(
wsi,
backed_file=None,
reader=None,
download=True,
pbar=True,
cache_dir=None,
name=None,
cache_dir=None,
pbar=True,
attach_images=False,
image_key="wsi",
save_images=False,
Expand Down Expand Up @@ -59,20 +59,13 @@ def open_wsi(
fs, wsi_path = url_to_fs(wsi)
if not fs.exists(wsi_path):
raise ValueError(f"Slide {wsi} not existed or not accessible.")
if name is None:
name = Path(wsi_path).name

# Early attempt with reader
ReaderCls = get_reader(reader)

if download and fs.protocol != "file":
if cache_dir is None:
cache_dir = Path("lazyslide_downloads")
if not Path(cache_dir).exists():
Path(cache_dir).mkdir(parents=True, exist_ok=True)
wsi = Path(cache_dir) / name
downloader = Downloader(wsi_path, wsi)
downloader.download(pbar)
downloader = CacheDownloader(wsi_path, name=name, cache_dir=cache_dir)
wsi = downloader.download(pbar)

reader_obj = ReaderCls(wsi)
wsi = Path(wsi)
Expand Down

0 comments on commit 0c96ffa

Please sign in to comment.