Skip to content

Commit

Permalink
Update benchmark script
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Feb 22, 2022
1 parent 518bd1e commit cb72d0a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
17 changes: 7 additions & 10 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ def set_up_parser() -> argparse.ArgumentParser:
"--patch-size",
default=224,
type=int,
help="height/width of each patch",
metavar="SIZE",
help="height/width of each patch in pixels",
metavar="PIXELS",
)
parser.add_argument(
"-s",
"--stride",
default=112,
type=int,
help="sampling stride for GridGeoSampler",
help="sampling stride for GridGeoSampler in pixels",
metavar="PIXELS",
)
parser.add_argument(
"-w",
Expand Down Expand Up @@ -139,15 +140,11 @@ def main(args: argparse.Namespace) -> None:
length = args.num_batches * args.batch_size
num_batches = args.num_batches

# Convert from pixel coords to CRS coords
size = args.patch_size * cdl.res
stride = args.stride * cdl.res

samplers = [
RandomGeoSampler(landsat, size=size, length=length),
GridGeoSampler(landsat, size=size, stride=stride),
RandomGeoSampler(landsat, size=args.patch_size, length=length),
GridGeoSampler(landsat, size=args.patch_size, stride=args.stride),
RandomBatchGeoSampler(
landsat, size=size, batch_size=args.batch_size, length=length
landsat, size=args.patch_size, batch_size=args.batch_size, length=length
),
]

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/samplers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class Units(Enum):
"""Enumeration defining units of ``size`` parameter.
Used by :class:`~torchgeo.sampler.GeoSampler` and
:class:`~torchgeo.sampler.BatchGeoSampler`.
Used by :class:`~torchgeo.samplers.GeoSampler` and
:class:`~torchgeo.samplers.BatchGeoSampler`.
"""

PIXELS = auto()
Expand Down
1 change: 0 additions & 1 deletion torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Tuple, Union

from ..datasets import BoundingBox
from .constants import Units


def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
Expand Down

0 comments on commit cb72d0a

Please sign in to comment.