Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wrap_file for wrapping a file object with callback #271

Merged
merged 1 commit into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/dvc_objects/fs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Sequence,
Tuple,
Union,
cast,
overload,
)
from urllib.parse import urlsplit, urlunsplit
Expand All @@ -34,8 +33,8 @@

from .callbacks import (
DEFAULT_CALLBACK,
CallbackStream,
wrap_and_branch_callback,
wrap_file,
)
from .errors import RemoteMissingDepsError

Expand Down Expand Up @@ -637,7 +636,7 @@ def put_file(
if size:
callback.set_size(size)
if hasattr(from_file, "read"):
stream = cast("BinaryIO", CallbackStream(from_file, callback))
stream = wrap_file(from_file, callback)
self.upload_fobj(stream, to_info, size=size)
else:
assert isinstance(from_file, str)
Expand Down
32 changes: 13 additions & 19 deletions src/dvc_objects/fs/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
import asyncio
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast

import fsspec

if TYPE_CHECKING:
from typing import BinaryIO, Union
from typing import Union

from dvc_objects._tqdm import Tqdm

F = TypeVar("F", bound=Callable)


class CallbackStream:
def __init__(self, stream, callback, method="read"):
def __init__(self, stream, callback: fsspec.Callback):
self.stream = stream
if method == "write":

@wraps(stream.write)
def write(data, *args, **kwargs):
res = stream.write(data, *args, **kwargs)
callback.relative_update(len(data))
return res
@wraps(stream.read)
def read(*args, **kwargs):
data = stream.read(*args, **kwargs)
callback.relative_update(len(data))
return data

self.write = write
else:

@wraps(stream.read)
def read(*args, **kwargs):
data = stream.read(*args, **kwargs)
callback.relative_update(len(data))
return data

self.read = read
self.read = read

def __getattr__(self, attr):
return getattr(self.stream, attr)
Expand Down Expand Up @@ -181,4 +171,8 @@ def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F:
return wrap_fn(callback, branch_wrapper)


def wrap_file(file, callback: fsspec.Callback) -> BinaryIO:
return cast(BinaryIO, CallbackStream(file, callback))


DEFAULT_CALLBACK = NoOpCallback()
6 changes: 3 additions & 3 deletions src/dvc_objects/fs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc_objects.executors import ThreadPoolExecutor

from . import system
from .callbacks import DEFAULT_CALLBACK, CallbackStream
from .callbacks import DEFAULT_CALLBACK, wrap_file

if TYPE_CHECKING:
from .base import AnyFSPath, FileSystem
Expand Down Expand Up @@ -168,8 +168,8 @@ def copyfile(

callback.set_size(total)
with open(src, "rb") as fsrc, open(dest, "wb+") as fdest:
wrapped = CallbackStream(fdest, callback, "write")
shutil.copyfileobj(fsrc, wrapped, length=LOCAL_CHUNK_SIZE)
wrapped = wrap_file(fsrc, callback)
shutil.copyfileobj(wrapped, fdest, length=LOCAL_CHUNK_SIZE)


def tmp_fname(prefix: str = "") -> str:
Expand Down
15 changes: 15 additions & 0 deletions tests/fs/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TqdmCallback,
branch_callback,
wrap_and_branch_callback,
wrap_file,
wrap_fn,
)

Expand Down Expand Up @@ -146,3 +147,17 @@ async def test_wrap_and_branch_callback_async(mocker, cb_class):
m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback())
assert callback.value == 2
assert spy.call_count == 2


def test_wrap_file(memfs):
memfs.pipe_file("/file", b"foo\n")

callback = Callback()

callback.set_size(4)
with memfs.open("/file", mode="rb") as f:
wrapped = wrap_file(f, callback)
assert wrapped.read() == b"foo\n"

assert callback.value == 4
assert callback.size == 4