Skip to content

Commit

Permalink
Optimize remote commands through StreamFlowPath
Browse files Browse the repository at this point in the history
This commit relies on the new `StreamFlowPath` abstraction to redirect
file-based commands to the lowest possible `ExecutionLocation` in the
wrapping hierarchy, so that they reach a `local` location whenever
possible. The main benefit of this strategy is that `local` locations
support Python-based commands, which are much faster than shell-based
remote processes.
  • Loading branch information
GlassOfWhiskey committed Dec 14, 2024
1 parent f83c0cb commit 97ae2e6
Showing 1 changed file with 56 additions and 8 deletions.
64 changes: 56 additions & 8 deletions streamflow/data/remotepath.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,40 @@ async def _size(
return int(result) if result.isdigit() else 0


def _traverse_bind_mounts(
location: ExecutionLocation, path: PurePath
) -> tuple[ExecutionLocation, pathlib.PurePath]:
while location.wraps:
for mount in sorted(location.mounts.keys(), reverse=True):
if str(path).startswith(mount):
path = (
Path(location.mounts[mount])
if location.wraps.local
else PurePosixPath(location.mounts[mount])
) / path.relative_to(mount)
location = location.wraps
break
else:
break
return location, path


class StreamFlowPath(PurePath, ABC):
def __new__(
cls, *args, context: StreamFlowContext, location: ExecutionLocation, **kwargs
):
if cls is StreamFlowPath:
location, path = _traverse_bind_mounts(
location=location,
path=(
Path(*args, **kwargs)
if location.local
else PurePosixPath(*args, **kwargs)
),
)
cls = LocalStreamFlowPath if location.local else RemoteStreamFlowPath
if sys.version_info < (3, 12):
return cls._from_parsed_parts(path._drv, path._root, path._parts)
if sys.version_info < (3, 12):
return cls._from_parts(args)
else:
Expand Down Expand Up @@ -254,13 +282,22 @@ def __init__(
self,
*args,
context: StreamFlowContext,
location: ExecutionLocation | None = None,
location: ExecutionLocation,
):
location, path = _traverse_bind_mounts(
location=location,
path=(Path(*args) if location.local else PurePosixPath(*args)),
)
if not location.local:
raise WorkflowExecutionException(
f"{self.__class__.__name__} should only be used on a local path."
)
if sys.version_info < (3, 12):
super().__init__()
else:
super().__init__(*args)
super().__init__(path)
self.context: StreamFlowContext = context
self.location: ExecutionLocation = location

async def checksum(self) -> str | None:
if await self.is_file():
Expand All @@ -278,7 +315,9 @@ async def glob(
self, pattern, *, case_sensitive=None
) -> AsyncIterator[LocalStreamFlowPath]:
for path in glob.glob(str(self / pattern)):
yield LocalStreamFlowPath(path, context=self.context)
yield LocalStreamFlowPath(
path, context=self.context, location=self.location
)

async def is_dir(self) -> bool:
return cast(Path, super()).is_dir()
Expand Down Expand Up @@ -330,7 +369,9 @@ async def read_text(self, n=-1, encoding=None, errors=None) -> str:
async def resolve(self, strict=False) -> LocalStreamFlowPath | None:
if await self.exists():
return LocalStreamFlowPath(
super().resolve(strict=strict), context=self.context
super().resolve(strict=strict),
context=self.context,
location=self.location,
)
else:
return None
Expand Down Expand Up @@ -376,7 +417,7 @@ async def walk(
yield dirpath, dirnames, filenames

def with_segments(self, *pathsegments):
return type(self)(*pathsegments, context=self.context)
return type(self)(*pathsegments, context=self.context, location=self.location)

async def write_text(self, data: str, **kwargs) -> int:
return cast(Path, super()).write_text(data=data, **kwargs)
Expand All @@ -389,10 +430,17 @@ class RemoteStreamFlowPath(
__slots__ = ("context", "connector", "location")

def __init__(self, *args, context: StreamFlowContext, location: ExecutionLocation):
location, path = _traverse_bind_mounts(
location=location, path=PurePosixPath(*args)
)
if sys.version_info < (3, 12):
super().__init__()
else:
super().__init__(*args)
super().__init__(path)
if location.local:
raise WorkflowExecutionException(
f"{self.__class__.__name__} should not be used on a local path."
)
self.context: StreamFlowContext = context
self.connector: Connector = self.context.deployment_manager.get_connector(
location.deployment
Expand Down Expand Up @@ -501,15 +549,15 @@ async def read_text(self, n=-1, encoding=None, errors=None) -> str:
_check_status(command, self.location, result, status)
return result.strip()

async def resolve(self, strict=False) -> RemoteStreamFlowPath | None:
async def resolve(self, strict=False) -> StreamFlowPath | None:
# If at least one primary location is present on the site, return its path
if locations := self.context.data_manager.get_data_locations(
path=self.__str__(),
deployment=self.connector.deployment_name,
location_name=self.location.name,
data_type=DataType.PRIMARY,
):
return RemoteStreamFlowPath(
return StreamFlowPath(
next(iter(locations)).path,
context=self.context,
location=next(iter(locations)).location,
Expand Down

0 comments on commit 97ae2e6

Please sign in to comment.