Skip to content

Commit

Permalink
Review comments#1
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Aug 12, 2022
1 parent 1f9dd67 commit 71f028a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 20 deletions.
26 changes: 10 additions & 16 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def __init__(
root: str,
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Sequence[str] = [],
bands: Optional[Sequence[str]] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
) -> None:
Expand All @@ -322,7 +322,7 @@ def __init__(
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
bands: list of band names to be used
bands: bands to return (defaults to all bands)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
Expand All @@ -334,7 +334,6 @@ def __init__(

self.root = root
self.cache = cache
self.bands = bands

# Populate the dataset index
i = 0
Expand Down Expand Up @@ -378,18 +377,13 @@ def __init__(
f"No {self.__class__.__name__} data was found in '{root}'"
)

if not self.all_bands:
band_indexes = None
if bands and self.all_bands:
band_indexes = [self.all_bands.index(i) + 1 for i in bands]
self.bands = bands
assert len(band_indexes) == len(self.bands)
else:
if self.bands:
band_indexes = [self.all_bands.index(i) + 1 for i in self.bands]
assert len(band_indexes) == len(self.bands)
else:
band_indexes = None

if self.rgb_bands:
rgb_band_indexes = [self.all_bands.index(i) + 1 for i in self.rgb_bands]
assert len(rgb_band_indexes) == len(self.rgb_bands)
band_indexes = None
self.bands = self.all_bands

self.band_indexes = band_indexes
self._crs = cast(CRS, crs)
Expand Down Expand Up @@ -418,7 +412,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
if self.separate_files:
data_list: List[Tensor] = []
filename_regex = re.compile(self.filename_regex, re.VERBOSE)
for band in getattr(self, "bands", self.all_bands):
for band in self.bands:
band_filepaths = []
for filepath in filepaths:
filename = os.path.basename(filepath)
Expand All @@ -435,7 +429,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
filename = filename[:start] + "*" + filename[end:]
filepath = glob.glob(os.path.join(directory, filename))[0]
band_filepaths.append(filepath)
data_list.append(self._merge_files(band_filepaths, query, [1]))
data_list.append(self._merge_files(band_filepaths, query))
data = torch.cat(data_list)
else:
data = self._merge_files(filepaths, query, self.band_indexes)
Expand Down
1 change: 0 additions & 1 deletion torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __init__(
Raises:
FileNotFoundError: if no files are found in ``root``
"""
bands = bands or self.all_bands
self.filename_glob = self.filename_glob.format(bands[0])

super().__init__(root, crs, res, bands, transforms, cache)
Expand Down
4 changes: 1 addition & 3 deletions torchgeo/datasets/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Sequence[str] = [],
bands: Optional[Sequence[str]] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
) -> None:
Expand All @@ -97,8 +97,6 @@ def __init__(
Raises:
FileNotFoundError: if no files are found in ``root``
"""
bands = bands or self.all_bands

super().__init__(root, crs, res, bands, transforms, cache)

def plot(
Expand Down

0 comments on commit 71f028a

Please sign in to comment.