Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing #25

Merged
merged 3 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

repos:
- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.5
hooks:
- id: codespell
args: [--ignore-words=.codespellignore]
Expand All @@ -26,7 +26,7 @@ repos:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
rev: v1.4.1
hooks:
- id: mypy
additional_dependencies:
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include stactask/py.typed
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[tool.mypy]
strict = true

[[tool.mypy.overrides]]
module = [
"boto3utils",
"jsonpath_ng.ext",
"fsspec",
]
ignore_missing_imports = true
50 changes: 31 additions & 19 deletions stactask/asset_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import logging
import os
from os import path as op
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import fsspec
from boto3utils import s3
from fsspec import AbstractFileSystem
from pystac import Item
from pystac.layout import LayoutTemplate

Expand All @@ -18,7 +19,7 @@
sem = asyncio.Semaphore(SIMULTANEOUS_DOWNLOADS)


async def download_file(fs, src, dest):
async def download_file(fs: AbstractFileSystem, src: str, dest: str) -> None:
async with sem:
logger.debug(f"{src} start")
await fs._get_file(src, dest)
Expand All @@ -32,9 +33,8 @@ async def download_item_assets(
overwrite: bool = False,
path_template: str = "${collection}/${id}",
absolute_path: bool = False,
**kwargs,
):

**kwargs: Any,
) -> Item:
_assets = item.assets.keys() if assets is None else assets

# determine path from template and item
Expand Down Expand Up @@ -77,44 +77,56 @@ async def download_item_assets(
return new_item


async def download_items_assets(items, **kwargs):
async def download_items_assets(items: List[Item], **kwargs: Any) -> List[Item]:
tasks = []
for item in items:
tasks.append(asyncio.create_task(download_item_assets(item, **kwargs)))
new_items = await asyncio.gather(*tasks)
new_items: List[Item] = await asyncio.gather(*tasks)
return new_items


def upload_item_assets_to_s3(
item: Item,
assets: Optional[List[str]] = None,
public_assets: List[str] = [],
public_assets: Union[None, List[str], str] = None,
path_template: str = "${collection}/${id}",
s3_urls: bool = False,
headers: Dict = {},
**kwargs,
) -> Dict:
headers: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Item:
"""Upload Item assets to s3 bucket
Args:
item (Dict): STAC Item
assets (List[str], optional): List of asset keys to upload. Defaults to None.
public_assets (List[str], optional): List of assets keys that should be public. Defaults to [].
path_template (str, optional): Path string template. Defaults to '${collection}/${id}'.
s3_urls (bool, optional): Return s3 URLs instead of http URLs. Defaults to False.
headers (Dict, optional): Dictionary of headers to set on uploaded assets. Defaults to {},
public_assets (List[str], optional): List of assets keys that should be
public. Defaults to [].
path_template (str, optional): Path string template. Defaults to
'${collection}/${id}'.
s3_urls (bool, optional): Return s3 URLs instead of http URLs. Defaults
to False.
headers (Dict, optional): Dictionary of headers to set on uploaded
assets. Defaults to {},
Returns:
Dict: A new STAC Item with uploaded assets pointing to newly uploaded file URLs
"""
if headers is None:
headers = {}

# deepcopy of item
_item = item.to_dict()

if public_assets is None:
public_assets = []
# determine which assets should be public
elif type(public_assets) is str:
if public_assets == "ALL":
public_assets = list(_item["assets"].keys())
else:
raise ValueError(f"unexpected value for `public_assets`: {public_assets}")

# if assets not provided, upload all assets
_assets = assets if assets is not None else _item["assets"].keys()

# determine which assets should be public
if type(public_assets) is str and public_assets == "ALL":
public_assets = _item["assets"].keys()

for key in [a for a in _assets if a in _item["assets"].keys()]:
asset = _item["assets"][key]
filename = asset["href"]
Expand Down
Empty file added stactask/py.typed
Empty file.
74 changes: 42 additions & 32 deletions stactask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from pathlib import Path
from shutil import rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import fsspec
from pystac import ItemCollection
from pystac import Item, ItemCollection

from .asset_io import (
download_item_assets,
Expand All @@ -27,8 +27,9 @@
# types
PathLike = Union[str, Path]
"""
Tasks can use parameters provided in a `process` Dictionary that is supplied in the ItemCollection
JSON under the "process" field. An example process definition:
Tasks can use parameters provided in a `process` Dictionary that is supplied in
the ItemCollection JSON under the "process" field. An example process
definition:

```
{
Expand All @@ -53,14 +54,13 @@


class Task(ABC):

name = "task"
description = "A task for doing things"
version = "0.1.0"

def __init__(
self: "Task",
payload: Dict,
payload: Dict[str, Any],
workdir: Optional[PathLike] = None,
save_workdir: bool = False,
skip_upload: bool = False,
Expand Down Expand Up @@ -89,18 +89,18 @@ def __init__(
self._workdir = Path(workdir)
makedirs(self._workdir, exist_ok=True)

def __del__(self):
def __del__(self) -> None:
# remove work directory if not running locally
if not self._save_workdir:
self.logger.debug("Removing work directory %s", self._workdir)
rmtree(self._workdir)

@property
def process_definition(self) -> Dict:
def process_definition(self) -> Dict[str, Any]:
return self._payload.get("process", {})

@property
def parameters(self) -> Dict:
def parameters(self) -> Dict[str, Any]:
task_configs = self.process_definition.get("tasks", [])
if isinstance(task_configs, List):
# tasks is a list
Expand All @@ -122,11 +122,11 @@ def parameters(self) -> Dict:
raise ValueError(f"unexpected value for 'tasks': {task_configs}")

@property
def upload_options(self) -> Dict:
def upload_options(self) -> Dict[str, Any]:
return self.process_definition.get("upload_options", {})

@property
def items_as_dicts(self) -> List[Dict]:
def items_as_dicts(self) -> List[Dict[str, Any]]:
return self._payload.get("features", [])

@property
Expand All @@ -135,12 +135,12 @@ def items(self) -> ItemCollection:
return ItemCollection.from_dict(items_dict, preserve_dict=True)

@classmethod
def validate(cls, payload: Dict) -> bool:
def validate(cls, payload: Dict[str, Any]) -> bool:
# put validation logic on input Items and process definition here
return True

@classmethod
def add_software_version(cls, items: List[Dict]):
def add_software_version(cls, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
processing_ext = (
"https://stac-extensions.github.io/processing/v1.1.0/schema.json"
)
Expand All @@ -154,7 +154,7 @@ def add_software_version(cls, items: List[Dict]):
i["properties"]["processing:software"] = {cls.name: cls.version}
return items

def assign_collections(self):
def assign_collections(self) -> None:
"""Assigns new collection names based on"""
for i, (coll, expr) in itertools.product(
self._payload["features"],
Expand All @@ -164,13 +164,18 @@ def assign_collections(self):
i["collection"] = coll

def download_item_assets(
self, item: Dict, path_template: str = "${collection}/${id}", **kwargs
):
"""Download provided asset keys for all items in payload. Assets are saved in workdir in a
directory named by the Item ID, and the items are updated with the new asset hrefs.
self,
item: Item,
path_template: str = "${collection}/${id}",
**kwargs: Any,
) -> Item:
"""Download provided asset keys for all items in payload. Assets are
saved in workdir in a directory named by the Item ID, and the items are
updated with the new asset hrefs.

Args:
assets (Optional[List[str]], optional): List of asset keys to download. Defaults to all assets.
assets (Optional[List[str]], optional): List of asset keys to
download. Defaults to all assets.
"""
outdir = str(self._workdir / path_template)
loop = asyncio.get_event_loop()
Expand All @@ -180,16 +185,21 @@ def download_item_assets(
return item

def download_items_assets(
self, items: List[Dict], path_template: str = "${collection}/${id}", **kwargs
):
self,
items: List[Item],
path_template: str = "${collection}/${id}",
**kwargs: Any,
) -> List[Item]:
outdir = str(self._workdir / path_template)
loop = asyncio.get_event_loop()
items = loop.run_until_complete(
download_items_assets(self.items, path_template=outdir, **kwargs)
download_items_assets(items, path_template=outdir, **kwargs)
)
return items

def upload_item_assets_to_s3(self, item: Dict, assets: Optional[List[str]] = None):
def upload_item_assets_to_s3(
self, item: Item, assets: Optional[List[str]] = None
) -> Item:
if self._skip_upload:
self.logger.warning("Skipping upload of new and modified assets")
return item
Expand All @@ -198,7 +208,7 @@ def upload_item_assets_to_s3(self, item: Dict, assets: Optional[List[str]] = Non

# this should be in PySTAC
@staticmethod
def create_item_from_item(item):
def create_item_from_item(item: Dict[str, Any]) -> Dict[str, Any]:
new_item = deepcopy(item)
# create a derived output item
links = [
Expand All @@ -217,7 +227,7 @@ def create_item_from_item(item):
return new_item

@abstractmethod
def process(self, **kwargs) -> List[Dict]:
def process(self, **kwargs: Any) -> List[Dict[str, Any]]:
"""Main task logic - virtual

Returns:
Expand All @@ -230,7 +240,7 @@ def process(self, **kwargs) -> List[Dict]:
pass

@classmethod
def handler(cls, payload: Dict, **kwargs) -> Dict[str, Any]:
def handler(cls, payload: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
if "href" in payload or "url" in payload:
# read input
with fsspec.open(payload.get("href", payload.get("url"))) as f:
Expand All @@ -250,7 +260,7 @@ def handler(cls, payload: Dict, **kwargs) -> Dict[str, Any]:
raise err

@classmethod
def parse_args(cls, args):
def parse_args(cls, args: List[str]) -> Dict[str, Any]:
dhf = argparse.ArgumentDefaultsHelpFormatter
parser0 = argparse.ArgumentParser(description=cls.description)
parser0.add_argument(
Expand Down Expand Up @@ -298,8 +308,8 @@ def parse_args(cls, args):
default=False,
)
h = """ Run local mode
(save-workdir = True, skip-upload = True, skip-validation = True,
workdir = 'local-output', output = 'local-output/output-payload.json') """
(save-workdir = True, skip-upload = True, skip-validation = True,
workdir = 'local-output', output = 'local-output/output-payload.json') """
parser.add_argument("--local", help=h, action="store_true", default=False)

# turn Namespace into dictionary
Expand All @@ -323,7 +333,7 @@ def parse_args(cls, args):
return pargs

@classmethod
def cli(cls):
def cli(cls) -> None:
args = cls.parse_args(sys.argv[1:])
cmd = args.pop("command")

Expand Down Expand Up @@ -365,9 +375,9 @@ def cli(cls):
from functools import wraps # noqa


def silence_event_loop_closed(func):
def silence_event_loop_closed(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
@wraps(func)
def wrapper(self, *args, **kwargs):
def wrapper(self, *args: Any, **kwargs: Any) -> Any: # type: ignore
try:
return func(self, *args, **kwargs)
except RuntimeError as e:
Expand Down
4 changes: 2 additions & 2 deletions stactask/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict
from typing import Any, Dict

from jsonpath_ng.ext import parser


def stac_jsonpath_match(item: Dict, expr: str) -> bool:
def stac_jsonpath_match(item: Dict[str, Any], expr: str) -> bool:
"""Match jsonpath expression against STAC JSON.
Use https://jsonpath.herokuapp.com/ to experiment with JSONpath
and https://regex101.com/ to experiment with regex
Expand Down
1 change: 0 additions & 1 deletion tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def process(self):


class FailValidateTask(Task):

name = "failvalidation-task"
description = "this task always fails validation"

Expand Down