From 5a9af2b59b4ab177d84e60713369e3b3a6b27fe2 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Tue, 5 Sep 2023 15:48:55 -0400 Subject: [PATCH 1/3] minor tutorial code improvements [skip ci] --- docs/usage/tutorials/stac_plus_osm.ipynb | 145 +++++++------ docs/usage/tutorials/temporal.ipynb | 246 +++++++++++------------ 2 files changed, 184 insertions(+), 207 deletions(-) diff --git a/docs/usage/tutorials/stac_plus_osm.ipynb b/docs/usage/tutorials/stac_plus_osm.ipynb index b99577f8f..58b11fb45 100644 --- a/docs/usage/tutorials/stac_plus_osm.ipynb +++ b/docs/usage/tutorials/stac_plus_osm.ipynb @@ -49,7 +49,7 @@ }, "outputs": [], "source": [ - "!pip install osm2geojson==0.2.4 pystac_client==0.6.1 stackstac==0.4.4" + "%pip install osm2geojson==0.2.4 pystac_client==0.6.1 stackstac==0.4.4" ] }, { @@ -62,18 +62,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "f72cd9be-2f52-4cd7-acb8-58b8dd9047c0", "metadata": {}, "outputs": [], "source": [ - "from rastervision.pipeline.file_system.utils import json_to_file, file_to_json\n", + "from rastervision.pipeline.file_system.utils import json_to_file\n", "from rastervision.core.box import Box\n", - "from rastervision.core.data import (\n", - " RasterioCRSTransformer, StatsTransformer, XarraySource)\n", + "from rastervision.core.data import (RasterioCRSTransformer, StatsTransformer,\n", + " XarraySource)\n", "from rastervision.core.data.raster_source import XarraySource\n", - "from rastervision.core.data.utils import (\n", - " geoms_to_geojson, geojson_to_geoms, get_polygons_from_uris)\n", + "from rastervision.core.data.utils import (geoms_to_geojson, geojson_to_geoms,\n", + " get_polygons_from_uris)\n", "\n", "from shapely.geometry import mapping\n", "from matplotlib import pyplot as plt\n", @@ -104,12 +104,16 @@ "\n", "overpass_api_endpoint = 'https://overpass-api.de/api/interpreter'\n", "\n", + "\n", "def fetch_osm_geojson(query: str) -> dict:\n", - " response = requests.get(f'{overpass_api_endpoint}?data=[out:xml][timeout:25];{query};out geom;')\n", - " string = response.text.replace(\"\\n\",\"\")\n", + " response = requests.get(\n", + " f'{overpass_api_endpoint}?data=[out:xml][timeout:25];{query};out geom;'\n", + " )\n", + " string = response.text.replace(\"\\n\", \"\")\n", " geojson = osm2geojson.xml2geojson(string)\n", " return geojson\n", "\n", + "\n", "def get_city_boundary(country: str, city: str) -> Box:\n", " query = f\"\"\"\n", " area[name=\"{country}\"][admin_level=2]->.country;\n", @@ -123,11 +127,13 @@ " geom = unary_union(list(geojson_to_geoms(geojson)))\n", " return geom\n", "\n", + "\n", "def get_city_bbox(country: str, city: str) -> Box:\n", " geom = get_city_boundary(country, city)\n", " city_bbox = Box.from_shapely(geom.envelope)\n", " return city_bbox\n", "\n", + "\n", "def get_city_features(country: str, city: str, query: str) -> dict:\n", " query_city = f\"\"\"\n", " area[name=\"{country}\"][admin_level=2]->.country;\n", @@ -210,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "ac3c0b96-7680-4db8-b239-c85c37c2bf7b", "metadata": {}, "outputs": [], @@ -229,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "07f26c38-8d92-4393-85d0-d3f20fa36c60", "metadata": {}, "outputs": [ @@ -245,7 +251,7 @@ " (2.224122, 48.8155755)),)})" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -267,18 +273,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "id": "86306996-4906-47ae-961b-dbf0bfd66a58", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 57.1 ms, sys: 16.2 ms, total: 73.3 ms\n", - "Wall time: 1.37 s\n" - ] - }, { "data": { "text/html": [ @@ -645,7 +643,7 @@ " stroke: currentColor;\n", " fill: currentColor;\n", "}\n", - "
<xarray.DataArray 'stackstac-e276df0c3cc905d489a335071a8f869d' (time: 1,\n",
+       "
<xarray.DataArray 'stackstac-c221d8123b0a1e61acf10801cf3e5756' (time: 1,\n",
        "                                                                band: 32,\n",
        "                                                                y: 10980,\n",
        "                                                                x: 10980)>\n",
@@ -656,7 +654,7 @@
        "  * band                                     (band) <U12 'aot' ... 'wvp-jp2'\n",
        "  * x                                        (x) float64 4e+05 ... 5.098e+05\n",
        "  * y                                        (y) float64 5.5e+06 ... 5.39e+06\n",
-       "    s2:datatake_type                         <U8 'INS-NOBS'\n",
+       "    view:sun_azimuth                         float64 163.6\n",
        "    ...                                       ...\n",
        "    raster:bands                             (band) object [{'nodata': 0, 'da...\n",
        "    gsd                                      (band) object None 10 ... None None\n",
@@ -668,7 +666,7 @@
        "    spec:        RasterSpec(epsg=32631, bounds=(399960.0, 5390220.0, 509760.0...\n",
        "    crs:         epsg:32631\n",
        "    transform:   | 10.00, 0.00, 399960.00|\\n| 0.00,-10.00, 5500020.00|\\n| 0.0...\n",
-       "    resolution:  10.0
  • spec :
    RasterSpec(epsg=32631, bounds=(399960.0, 5390220.0, 509760.0, 5500020.0), resolutions_xy=(10.0, 10.0))
    crs :
    epsg:32631
    transform :
    | 10.00, 0.00, 399960.00|\n", "| 0.00,-10.00, 5500020.00|\n", "| 0.00, 0.00, 1.00|
    resolution :
    10.0
  • " ], "text/plain": [ - "\n", @@ -965,7 +963,7 @@ " * band (band)
    <xarray.DataArray 'stackstac-f40405ced835078c43a6fab9878fedf4' (time: 6,\n",
    +       "
    <xarray.DataArray 'stackstac-324721a98d3cff22ae730b7f1736dac4' (time: 6,\n",
            "                                                                band: 32,\n",
            "                                                                y: 10980,\n",
            "                                                                x: 10980)>\n",
    @@ -519,9 +519,9 @@
            "  * band                                     (band) <U12 'aot' ... 'wvp-jp2'\n",
            "  * x                                        (x) float64 4e+05 ... 5.098e+05\n",
            "  * y                                        (y) float64 5.5e+06 ... 5.39e+06\n",
    -       "    s2:degraded_msi_data_percentage          (time) float64 0.0322 ... 0.0308\n",
    +       "    mgrs:grid_square                         <U2 'DQ'\n",
            "    ...                                       ...\n",
    -       "    title                                    (band) <U31 'Aerosol optical thi...\n",
    +       "    raster:bands                             (band) object [{'nodata': 0, 'da...\n",
            "    gsd                                      (band) object None 10 ... None None\n",
            "    common_name                              (band) object None 'blue' ... None\n",
            "    center_wavelength                        (band) object None 0.49 ... None\n",
    @@ -531,7 +531,7 @@
            "    spec:        RasterSpec(epsg=32631, bounds=(399960.0, 5390220.0, 509760.0...\n",
            "    crs:         epsg:32631\n",
            "    transform:   | 10.00, 0.00, 399960.00|\\n| 0.00,-10.00, 5500020.00|\\n| 0.0...\n",
    -       "    resolution:  10.0
  • spec :
    RasterSpec(epsg=32631, bounds=(399960.0, 5390220.0, 509760.0, 5500020.0), resolutions_xy=(10.0, 10.0))
    crs :
    epsg:32631
    transform :
    | 10.00, 0.00, 399960.00|\n", "| 0.00,-10.00, 5500020.00|\n", "| 0.00, 0.00, 1.00|
    resolution :
    10.0
  • " ], "text/plain": [ - "\n", @@ -892,9 +892,9 @@ " * band (band) Date: Wed, 6 Sep 2023 11:20:03 -0400 Subject: [PATCH 2/3] rephrase notebook header note about env variables --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 76dcd9f21..aac152162 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -154,7 +154,7 @@ def setup(app: 'Sphinx') -> None: .. note:: - If running outside of the Docker image, you might need to set a couple of environment variables manually. You can do it like so: + If running outside of the Docker image, you may need to set some environment variables manually. You can do it like so: .. code-block:: python From f3ceb3aa1b2b09e091d1108d79bcade905be9e45 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 6 Sep 2023 12:37:12 -0400 Subject: [PATCH 3/3] misc. docstring and type hint fixes [skip ci] --- .../pytorch_learner/learner_config.py | 56 +++++++++++-------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index 2c16e66de..30d4cb5eb 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -177,7 +177,7 @@ def build(self, save_dir: str, hubconf_dir: Optional[str] = None) -> Any: external source but instead from this dir. Defaults to None. Returns: - Any: The module loaded via torch.hub. + The module loaded via torch.hub. """ if hubconf_dir is not None: log.info(f'Using existing module definition at: {hubconf_dir}') @@ -267,7 +267,7 @@ def build(self, **kwargs: Extra args for :meth:`.build_default_model`. Returns: - nn.Module: a PyTorch nn.Module. + A PyTorch nn.Module. """ if self.external_def is not None: return self.build_external_model( @@ -284,7 +284,7 @@ def build_default_model(self, num_classes: int, in_channels: int, will be fed into the model. Defaults to 3. Returns: - nn.Module: a PyTorch nn.Module. + A PyTorch nn.Module. """ raise NotImplementedError() @@ -299,7 +299,7 @@ def build_external_model(self, Defaults to None. Returns: - nn.Module: a PyTorch nn.Module. + A PyTorch nn.Module. """ return self.external_def.build(save_dir, hubconf_dir=hubconf_dir) @@ -382,7 +382,7 @@ def build_loss(self, external_loss_def if specified. Defaults to None. Returns: - Callable: Loss function. + Loss function. """ if self.external_loss_def is not None: return self.external_loss_def.build( @@ -414,7 +414,7 @@ def build_optimizer(self, model: nn.Module, **kwargs) -> optim.Adam: **kwargs: Extra args for the optimizer constructor. Returns: - optim.Optimizer: An Adam optimzer instance. + An Adam optimzer instance. """ return optim.Adam(model.parameters(), lr=self.lr, **kwargs) @@ -435,8 +435,7 @@ def build_step_scheduler(self, **kwargs: Extra args for the scheduler constructor. Returns: - Optional[_LRScheduler]: A step scheduler, if applicable. Otherwise, - None. + A step scheduler, if applicable. Otherwise, None. """ scheduler = None if self.one_cycle and self.num_epochs > 1: @@ -475,8 +474,7 @@ def build_epoch_scheduler(self, **kwargs: Extra args for the scheduler constructor. Returns: - Optional[_LRScheduler]: An epoch scheduler, if applicable. Otherwise, - None. + An epoch scheduler, if applicable. Otherwise, None. """ scheduler = None if self.multi_stage: @@ -727,9 +725,19 @@ def get_bbox_params(self) -> Optional[A.BboxParams]: def get_data_transforms(self) -> Tuple[A.BasicTransform, A.BasicTransform]: """Get albumentations transform objects for data augmentation. + Returns a 2-tuple of a "base" transform and an augmentation transform. + The base transform comprises a resize transform based on img_sz + followed by the transform specified in base_transform. The augmentation + transform comprises the base transform followed by either the transform + in aug_transform (if specified) or the transforms in the augmentors + field. + + The augmentation transform is intended to be used for training data, + and the base transform for all other data where data augmentation is + not desirable, such as validation or prediction. + Returns: - 1st tuple arg: a transform that doesn't do any data augmentation - 2nd tuple arg: a transform with data augmentation + base transform and augmentation transform. """ bbox_params = self.get_bbox_params() base_tfs = [A.Resize(self.img_sz, self.img_sz)] @@ -876,8 +884,7 @@ def make_datasets(self, test dataset. Defaults to None. Returns: - Tuple[Dataset, Dataset, Dataset]: PyTorch-compatiable training, - validation, and test datasets. + PyTorch-compatiable training, validation, and test datasets. """ train_ds_list = [self.dir_to_dataset(d, train_tf) for d in train_dirs] val_ds_list = [self.dir_to_dataset(d, val_tf) for d in val_dirs] @@ -938,8 +945,7 @@ def get_datasets_from_uri( images. Returns: - Tuple[Dataset, Dataset, Dataset]: Training, validation, and test - dataSets. + Training, validation, and test dataSets. """ data_dirs = self.get_data_dirs(uri, unzip_dir=tmp_dir) @@ -1013,7 +1019,7 @@ def get_data_dirs(self, uri: Union[str, List[str]], (optinally) "test" subdirectories. Args: - uri (Union[str, List[str]]): a URI or a list of URIs of one of the + uri (Union[str, List[str]]): A URI or a list of URIs of one of the following: (1) a URI of a directory containing "train", "valid", and @@ -1021,9 +1027,12 @@ def get_data_dirs(self, uri: Union[str, List[str]], (2) a URI of a zip file containing (1) (3) a list of (2) (4) a URI of a directory containing zip files containing (1) + unzip_dir (str): Directory where zip files will be extrated to, if + needed. Returns: - paths to directories that each contain contents of one zip file + List[str]: Paths to directories that each contain contents of one + zip file. """ def is_data_dir(uri: str) -> bool: @@ -1066,11 +1075,12 @@ def unzip_data(self, zip_uris: List[str], unzip_dir: str) -> List[str]: """Unzip dataset zip files. Args: - zip_uris (List[str]): a list of URIs of zip files: - unzip_dir (str): directory where zip files will be extrated to. + zip_uris (List[str]): A list of URIs of zip files: + unzip_dir (str): Directory where zip files will be extrated to. Returns: - paths to directories that each contain contents of one zip file + List[str]: Paths to directories that each contain contents of one + zip file. """ data_dirs = [] @@ -1275,7 +1285,7 @@ def make_datasets(self, Returns: Tuple[Dataset, Dataset, Dataset]: PyTorch-compatiable training, - validation, and test datasets. + validation, and test datasets. """ train_scenes, val_scenes, test_scenes = self.build_scenes(tmp_dir) @@ -1413,7 +1423,7 @@ def build(self, model_weights_path: Optional[str] = None, model_def_path: Optional[str] = None, loss_def_path: Optional[str] = None, - training=True) -> 'Learner': + training: bool = True) -> 'Learner': """Returns a Learner instantiated using this Config. Args: