Misc. small fixes and improvements #2136

Merged · 8 commits · May 2, 2024
105 changes: 60 additions & 45 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Any, Iterator, Tuple
import io
import os
import subprocess
@@ -16,41 +16,38 @@


# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
def get_matching_s3_objects(bucket, prefix='', suffix='',
request_payer='None'):
"""
Generate objects in an S3 bucket.

:param bucket: Name of the S3 bucket.
:param prefix: Only fetch objects whose key starts with
this prefix (optional).
:param suffix: Only fetch objects whose keys end with
this suffix (optional).
def get_matching_s3_objects(
bucket: str,
prefix: str = '',
suffix: str = '',
delimiter: str = '/',
request_payer: str = 'None') -> Iterator[tuple[str, Any]]:
"""Generate objects in an S3 bucket.

Args:
bucket: Name of the S3 bucket.
prefix: Only fetch objects whose key starts with this prefix.
suffix: Only fetch objects whose keys end with this suffix.
"""
s3 = S3FileSystem.get_client()
kwargs = {'Bucket': bucket, 'RequestPayer': request_payer}

# If the prefix is a single string (not a tuple of strings), we can
# do the filtering directly in the S3 API.
if isinstance(prefix, str):
kwargs['Prefix'] = prefix

kwargs = dict(
Bucket=bucket,
RequestPayer=request_payer,
Delimiter=delimiter,
Prefix=prefix,
)
while True:

# The S3 API response is a large blob of metadata.
# 'Contents' contains information about the listed objects.
resp = s3.list_objects_v2(**kwargs)

try:
contents = resp['Contents']
except KeyError:
return

for obj in contents:
resp: dict = s3.list_objects_v2(**kwargs)
dirs: list[dict] = resp.get('CommonPrefixes', {})
files: list[dict] = resp.get('Contents', {})
for obj in dirs:
key = obj['Prefix']
if key.startswith(prefix) and key.endswith(suffix):
yield key, obj
for obj in files:
key = obj['Key']
if key.startswith(prefix) and key.endswith(suffix):
yield obj

yield key, obj
# The S3 API is paginated, returning up to 1000 keys at a time.
# Pass the continuation token into the next response, until we
# reach the final page (when this field is missing).
@@ -60,16 +57,26 @@ def get_matching_s3_objects(bucket, prefix='', suffix='',
break


def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'):
"""
Generate the keys in an S3 bucket.
def get_matching_s3_keys(bucket: str,
prefix: str = '',
suffix: str = '',
delimiter: str = '/',
request_payer: str = 'None') -> Iterator[str]:
"""Generate the keys in an S3 bucket.

:param bucket: Name of the S3 bucket.
:param prefix: Only fetch keys that start with this prefix (optional).
:param suffix: Only fetch keys that end with this suffix (optional).
Args:
bucket: Name of the S3 bucket.
prefix: Only fetch keys that start with this prefix.
suffix: Only fetch keys that end with this suffix.
"""
for obj in get_matching_s3_objects(bucket, prefix, suffix, request_payer):
yield obj['Key']
obj_iterator = get_matching_s3_objects(
bucket,
prefix=prefix,
suffix=suffix,
delimiter=delimiter,
request_payer=request_payer)
out = (key for key, _ in obj_iterator)
return out


def progressbar(total_size: int, desc: str):
@@ -180,8 +187,9 @@ def read_bytes(uri: str) -> bytes:
bucket, key = S3FileSystem.parse_uri(uri)
with io.BytesIO() as file_buffer:
try:
file_size = s3.head_object(
Bucket=bucket, Key=key)['ContentLength']
obj = s3.head_object(
Bucket=bucket, Key=key, RequestPayer=request_payer)
file_size = obj['ContentLength']
with progressbar(file_size, desc='Downloading') as bar:
s3.download_fileobj(
Bucket=bucket,
@@ -256,7 +264,9 @@ def copy_from(src_uri: str, dst_path: str) -> None:
request_payer = S3FileSystem.get_request_payer()
bucket, key = S3FileSystem.parse_uri(src_uri)
try:
file_size = s3.head_object(Bucket=bucket, Key=key)['ContentLength']
obj = s3.head_object(
Bucket=bucket, Key=key, RequestPayer=request_payer)
file_size = obj['ContentLength']
with progressbar(file_size, desc=f'Downloading') as bar:
s3.download_file(
Bucket=bucket,
@@ -284,11 +294,16 @@ def last_modified(uri: str) -> datetime:
return head_data['LastModified']

@staticmethod
def list_paths(uri, ext=''):
def list_paths(uri: str, ext: str = '', delimiter: str = '/') -> list[str]:
request_payer = S3FileSystem.get_request_payer()
parsed_uri = urlparse(uri)
bucket = parsed_uri.netloc
prefix = os.path.join(parsed_uri.path[1:])
keys = get_matching_s3_keys(
bucket, prefix, suffix=ext, request_payer=request_payer)
return [os.path.join('s3://', bucket, key) for key in keys]
bucket,
prefix,
suffix=ext,
delimiter=delimiter,
request_payer=request_payer)
paths = [os.path.join('s3://', bucket, key) for key in keys]
return paths
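
As a usage note, here is a minimal sketch of how the updated listing helpers might be called after this change; the bucket name and prefix are hypothetical, and configured AWS credentials are assumed:

from rastervision.aws_s3.s3_file_system import (S3FileSystem,
                                                get_matching_s3_keys)

# Hypothetical bucket/prefix. With the new delimiter parameter, listing
# stops at '/' boundaries: matching keys directly under the prefix are
# yielded (from 'Contents'), along with any "subdirectory" prefixes
# (from 'CommonPrefixes') that also match the suffix.
for key in get_matching_s3_keys(
        'my-bucket', prefix='tiles/', suffix='.tif', delimiter='/'):
    print(key)

# list_paths() forwards the delimiter and returns full s3:// URIs. The
# head_object calls now also pass RequestPayer, so size lookups work on
# requester-pays buckets.
paths = S3FileSystem.list_paths('s3://my-bucket/tiles/', ext='.tif')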
9 changes: 9 additions & 0 deletions rastervision_core/rastervision/core/data/dataset_config.py
@@ -90,3 +90,12 @@ def get_split_config(self, split_ind, num_splits):
@property
def all_scenes(self) -> List[SceneConfig]:
return self.train_scenes + self.validation_scenes + self.test_scenes

def __repr__(self):
num_train = len(self.train_scenes)
num_val = len(self.validation_scenes)
num_test = len(self.test_scenes)
out = (f'DatasetConfig(train_scenes=<{num_train} scenes>, '
f'validation_scenes=<{num_val} scenes>, '
f'test_scenes=<{num_test} scenes>)')
return out
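
For illustration, assuming a hypothetical DatasetConfig with an 80/10/10 scene split, the new __repr__ prints a compact summary instead of dumping every scene config:

print(repr(dataset_cfg))  # dataset_cfg is a hypothetical DatasetConfig
# DatasetConfig(train_scenes=<80 scenes>, validation_scenes=<10 scenes>, test_scenes=<10 scenes>)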
@@ -147,24 +147,20 @@ def get_chip(self,

return chip

def get_chip_by_map_window(
self,
window_map_coords: 'Box',
out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray':
"""Same as get_chip(), but input is a window in map coords. """
def get_chip_by_map_window(self, window_map_coords: 'Box', *args,
**kwargs) -> 'np.ndarray':
"""Same as get_chip(), but input is a window in map coords."""
window_pixel_coords = self.crs_transformer.map_to_pixel(
window_map_coords, bbox=self.bbox).normalize()
chip = self.get_chip(window_pixel_coords, out_shape=out_shape)
chip = self.get_chip(window_pixel_coords, *args, **kwargs)
return chip

def _get_chip_by_map_window(
self,
window_map_coords: 'Box',
out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray':
"""Same as _get_chip(), but input is a window in map coords. """
def _get_chip_by_map_window(self, window_map_coords: 'Box', *args,
**kwargs) -> 'np.ndarray':
"""Same as _get_chip(), but input is a window in map coords."""
window_pixel_coords = self.crs_transformer.map_to_pixel(
window_map_coords, bbox=self.bbox)
chip = self._get_chip(window_pixel_coords, out_shape=out_shape)
chip = self._get_chip(window_pixel_coords, *args, **kwargs)
return chip

def get_raw_chip(self,
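A brief sketch of what the relaxed signatures allow: any positional or keyword arguments accepted by get_chip() now pass through get_chip_by_map_window() unchanged, rather than out_shape alone (the raster source and window below are hypothetical):

# `rs` is a hypothetical RasterSource and `map_window` a Box in map
# coordinates; *args/**kwargs are forwarded to get_chip() verbatim, so
# out_shape still works, as would any other get_chip() argument.
chip = rs.get_chip_by_map_window(map_window, out_shape=(256, 256))
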
82 changes: 50 additions & 32 deletions rastervision_pipeline/rastervision/pipeline/cli.py
@@ -1,18 +1,20 @@
from typing import TYPE_CHECKING
import sys
import os
import logging
import importlib
import importlib.util
from typing import List, Dict, Optional, Tuple

import click

from rastervision.pipeline import (registry_ as registry, rv_config_ as
rv_config)
from rastervision.pipeline.file_system import (file_to_json, get_tmp_dir)
from rastervision.pipeline.config import build_config, save_pipeline_config
from rastervision.pipeline.config import (build_config, Config,
save_pipeline_config)
from rastervision.pipeline.pipeline_config import PipelineConfig

if TYPE_CHECKING:
from rastervision.pipeline.runner import Runner

log = logging.getLogger(__name__)


@@ -40,8 +42,9 @@ def convert_bool_args(args: dict) -> dict:
return new_args


def get_configs(cfg_module_path: str, runner: str,
args: Dict[str, any]) -> List[PipelineConfig]:
def get_configs(cfg_module_path: str,
runner: str | None = None,
args: dict[str, any] | None = None) -> list[PipelineConfig]:
"""Get PipelineConfigs from a module.

Calls a get_config(s) function with some arguments from the CLI
@@ -55,6 +58,26 @@ def get_configs(cfg_module_path: str,
args: CLI args to pass to the get_config(s) function that comes from
the --args option
"""
if cfg_module_path.endswith('.json'):
cfgs_json = file_to_json(cfg_module_path)
if not isinstance(cfgs_json, list):
cfgs_json = [cfgs_json]
cfgs = [Config.deserialize(json) for json in cfgs_json]
else:
cfgs = get_configs_from_module(cfg_module_path, runner, args)

for cfg in cfgs:
if not issubclass(type(cfg), PipelineConfig):
raise TypeError('All objects returned by get_configs in '
f'{cfg_module_path} must be PipelineConfigs.')
return cfgs


def get_configs_from_module(cfg_module_path: str, runner: str,
args: dict[str, any]) -> list[PipelineConfig]:
import importlib
import importlib.util

if cfg_module_path.endswith('.py'):
# From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa
spec = importlib.util.spec_from_file_location('rastervision.pipeline',
@@ -65,20 +88,14 @@
cfg_module = importlib.import_module(cfg_module_path)

_get_config = getattr(cfg_module, 'get_config', None)
_get_configs = _get_config
if _get_config is None:
_get_configs = getattr(cfg_module, 'get_configs', None)
_get_configs = getattr(cfg_module, 'get_configs', _get_config)
if _get_configs is None:
raise Exception('There must be a get_config or get_configs function '
f'in {cfg_module_path}.')
raise ImportError('There must be a get_config() or get_configs() '
f'function in {cfg_module_path}.')

cfgs = _get_configs(runner, **args)
if not isinstance(cfgs, list):
cfgs = [cfgs]

for cfg in cfgs:
if not issubclass(type(cfg), PipelineConfig):
raise Exception('All objects returned by get_configs in '
f'{cfg_module_path} must be PipelineConfigs.')
return cfgs


@@ -89,8 +106,7 @@ def get_configs(cfg_module_path: str, runner: str,
@click.option(
'-v', '--verbose', help='Increment the verbosity level.', count=True)
@click.option('--tmpdir', help='Root of temporary directories to use.')
def main(ctx: click.Context, profile: Optional[str], verbose: int,
tmpdir: str):
def main(ctx: click.Context, profile: str | None, verbose: int, tmpdir: str):
"""The main click command.

Sets the profile, verbosity, and tmp_dir in RVConfig.
@@ -103,20 +119,22 @@ def main(ctx: click.Context, profile: Optional[str], verbose: int,
rv_config.set_everett_config(profile=profile)


def _run_pipeline(cfg,
runner,
tmp_dir,
splits=1,
commands=None,
def _run_pipeline(cfg: PipelineConfig,
runner: 'Runner',
tmp_dir: str,
splits: int = 1,
commands: list[str] | None = None,
pipeline_run_name: str = 'raster-vision'):
cfg.update()
cfg.recursive_validate_config()
# This is to run the validation again to check any fields that may have changed
# after the Config was constructed, possibly by the update method.

# This is to run the validation again to check any fields that may have
# changed after the Config was constructed, possibly by the update method.
build_config(cfg.dict())
cfg_json_uri = cfg.get_config_uri()
save_pipeline_config(cfg, cfg_json_uri)
pipeline = cfg.build(tmp_dir)

if not commands:
commands = pipeline.commands

@@ -150,8 +168,8 @@ def _run_pipeline(cfg,
'--pipeline-run-name',
default='raster-vision',
help='The name for this run of the pipeline.')
def run(runner: str, cfg_module: str, commands: List[str],
arg: List[Tuple[str, str]], splits: int, pipeline_run_name: str):
def run(runner: str, cfg_module: str, commands: list[str],
arg: list[tuple[str, str]], splits: int, pipeline_run_name: str):
"""Run COMMANDS within pipelines in CFG_MODULE using RUNNER.

RUNNER: name of the Runner to use
@@ -178,9 +196,9 @@ def run(runner: str, cfg_module: str, commands: List[str],

def _run_command(cfg_json_uri: str,
command: str,
split_ind: Optional[int] = None,
num_splits: Optional[int] = None,
runner: Optional[str] = None):
split_ind: int | None = None,
num_splits: int | None = None,
runner: str | None = None):
"""Run a single command using a serialized PipelineConfig.

Args:
@@ -229,8 +247,8 @@ def _run_command(cfg_json_uri: str,
help='The number of processes to use for running splittable commands')
@click.option(
'--runner', type=str, help='Name of runner to use', default='inprocess')
def run_command(cfg_json_uri: str, command: str, split_ind: Optional[int],
num_splits: Optional[int], runner: str):
def run_command(cfg_json_uri: str, command: str, split_ind: int | None,
num_splits: int | None, runner: str):
"""Run a single COMMAND using a serialized PipelineConfig in CFG_JSON_URI."""
_run_command(
cfg_json_uri,
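To illustrate the new .json branch of get_configs(): a pipeline config previously written out by save_pipeline_config() can now be fed back in directly, without the original Python module. A hedged sketch, with hypothetical paths:

from rastervision.pipeline.cli import get_configs

# Deserializes the saved config(s); raises TypeError if anything in the
# file is not a PipelineConfig.
cfgs = get_configs('/opt/data/output/pipeline-config.json')

# Hypothetical equivalent CLI invocation:
#   rastervision run inprocess /opt/data/output/pipeline-config.json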