Concerns with GridGeoSampler for evaluation #1245
Comments
I'm fine with this, and can take it.
@calebrob6 any updates on this?
With some delay, no trouble with this on my end. Thanks for asking.
Hi, I know this is way after the fact, but when I use `ignore_index=0` inside my `PixelClassificationModel` class:

```python
import pytorch_lightning as pl
from torch import nn
from torchvision import models


class PixelClassificationModel(pl.LightningModule):
    def __init__(self, num_classes, in_channels, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate
        # Example of using a larger model
        self.model = models.segmentation.deeplabv3_resnet50(weights=None, num_classes=num_classes)
        # self.model = models.segmentation.deeplabv3_mobilenet_v3_large(weights=None, num_classes=num_classes)
        # Replace the first convolution so the backbone accepts in_channels input bands
        self.model.backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Pixels labeled 0 are excluded from the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
```

I have the following issue:
I study habitats along rivers, so in any given raster ~95% of the data are areas that I don't want to classify (e.g. urban areas, upland habitat). I've set these areas to a label of 0 because I want to mask them. The ignore index of 0 is not working as expected, so I've tried other ways to omit these tiles from the model. I've spent a lot of time trying to filter the images I don't care about out of the batches. I can't get it to work with a custom filter function for the DataLoader or by ignoring these images in the trainer and validator, and each of these options produces a model with very odd predictions. I can share code and some sample .tifs for labels and images if that would help, but I'm at a loss as to how to get around this.
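One way to find patches that are almost entirely background before they reach the loss is to measure the fraction of labeled (non-0) pixels in each mask. The sketch below is a minimal example under these assumptions:

- Batches are NumPy arrays, with images shaped (N, C, H, W) and masks shaped (N, H, W).
- `labeled_fraction`, `filter_batch`, and the 5% `min_fraction` threshold are hypothetical choices for illustration, not TorchGeo APIs.

It is not a confirmed fix for the odd predictions described above.

```python
import numpy as np


def labeled_fraction(mask, ignore_index=0):
    """Fraction of pixels in a single label mask that are not ignore_index."""
    return np.count_nonzero(mask != ignore_index) / mask.size


def filter_batch(images, masks, ignore_index=0, min_fraction=0.05):
    """Hypothetical filter: keep only patches with at least min_fraction labeled pixels.

    images: (N, C, H, W) array; masks: (N, H, W) integer array.
    """
    keep = np.array(
        [labeled_fraction(m, ignore_index) >= min_fraction for m in masks], dtype=bool
    )
    # Boolean indexing along the batch axis; the result may be empty
    return images[keep], masks[keep]


# Usage: two 4x4 patches, one pure background and one half labeled with class 5
masks = np.zeros((2, 4, 4), dtype=np.int64)
masks[1, :, :2] = 5
images = np.random.rand(2, 3, 4, 4)
kept_images, kept_masks = filter_batch(images, masks)
print(kept_images.shape)  # (1, 3, 4, 4)
```

An empty result means the whole batch was background, so the training step would have to skip it. Any 0 pixels left inside the kept patches still need `ignore_index`.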
Hi @FogDrip, it doesn't seem like you're using TorchGeo in your code snippet. Can you try using
Hi @adamjstewart, thank you so much for your quick response. I'm eager to learn and use TorchGeo, as I'm impressed by its customizability relative to ArcGIS DL. I switched to using `SemanticSegmentationTask`:

```python
from torchgeo.trainers import SemanticSegmentationTask

# num_input_channels, num_classes, and learning_rate are defined elsewhere in my script
task = SemanticSegmentationTask(
    model="deeplabv3+",
    backbone="resnet50",
    in_channels=num_input_channels,
    num_classes=num_classes,
    ignore_index=0,
    lr=learning_rate,
    loss="ce",
)
```

I'm now running into issues around:
I'm not sure if this relates to a custom label class I created for my labels that inherits from `Chesapeake`. I'm using the default `NAIP` class for the NAIP imagery in my study area, so I'm less concerned there.

```python
from torchgeo.datasets import Chesapeake


class ChesapeakeCA(Chesapeake):
    """Custom Chesapeake dataset class for California."""

    base_folder = "ARP"
    filename = "label_clipped.tif"
    filename_glob = filename
    ...
```

To solve this frozen problem, I've tried using a custom collate function to filter out the bounding box and CRS. That leads me down a rabbit hole of mismatched tensor shapes, so that solution causes other problems.
I get this error when training models with Lightning and
Hi @calebrob6, thanks for the quick reply. I saw your colleague's presentation at ESA on fence classification using DeepLab, which was excellent. I was able to get rid of the bounds with a custom collate function. I initially tried to do it within the custom class `ChesapeakeCA(Chesapeake)` by overriding `__getitem__` with `del sample['bbox']`, but that didn't remove them permanently.

```python
from torch.utils.data import default_collate


def custom_collate_fn(batch):
    """Custom collate function to remove 'bbox', 'bounds', and 'crs' fields before batching."""
    for sample in batch:
        # These metadata fields are not tensors and cause issues with default_collate
        for key in ("bbox", "bounds", "crs"):
            sample.pop(key, None)
    return default_collate(batch)
```

I'm still getting the following feedback:
I think this could be related to the case where all the values in a batch's label mask equal 0; then it will be completely masked by
How many classes do you have in your task? If you only have 2 classes and you ignore 1 of them, I wouldn't be surprised to see this.
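A small NumPy sketch shows why a fully masked batch is a problem. A mean cross-entropy over non-ignored pixels divides by the number of such pixels, and that number is zero when every target equals `ignore_index`. As far as we know, `nn.CrossEntropyLoss(ignore_index=0)` with the default `reduction="mean"` behaves this way and returns NaN for such a batch. The `masked_mean_ce` function and its `empty_value` parameter are hypothetical and only reproduce that arithmetic.

```python
import numpy as np


def masked_mean_ce(logits, target, ignore_index=0, empty_value=float("nan")):
    """Mean cross-entropy over pixels whose target is not ignore_index.

    logits: (C, H, W) raw class scores; target: (H, W) integer labels.
    """
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    rows, cols = np.nonzero(target != ignore_index)
    if rows.size == 0:
        # No labeled pixels: the mean would be 0 / 0, so return empty_value instead
        return empty_value
    # Negative log-probability of the true class at each labeled pixel
    nll = -log_probs[target[rows, cols], rows, cols]
    return float(nll.mean())


logits = np.zeros((3, 2, 2))  # uniform scores over 3 classes
all_background = np.zeros((2, 2), dtype=np.int64)
one_labeled = all_background.copy()
one_labeled[0, 0] = 1

print(masked_mean_ce(logits, all_background))  # nan
print(masked_mean_ce(logits, one_labeled))     # log(3), about 1.0986
```

Returning `empty_value=0.0`, or skipping the step, is one hypothetical way to keep an all-background batch from producing a NaN loss. Whether this explains the behaviour described in this thread is not confirmed.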
Hi @adamjstewart, thanks for your quick response. I have 18 classes but I'm trying to ignore class 0. I hope this attachment provides context: The prediction looks okay in some areas, so I'm wondering if I should just ignore the
A couple of issues that may affect the usage of GridGeoSampler for benchmarking.
Overlapping patches
Back in #630, we modified GridGeoSampler to ensure that every part of the image is sampled from, even if the height/width of the image is not a multiple of the stride. At the time, I decided that we should adjust the stride of the last row/col in order to avoid sampling outside of the bounds of the image. In hindsight, I think this was a mistake.
The problem is that we end up with the last row/col overlapping with the second-to-last row/col, resulting in some areas being double counted when computing performance metrics. It also makes stitching together prediction patches unnecessarily complicated.
I think we should modify GridGeoSampler to stop adjusting the stride and instead sample outside the bounds of the image. I believe rasterio will simply return nodata pixels for areas outside of the image. @remtav, are you okay with this solution? I believe this was actually the first idea you implemented; apologies for pushing that PR in the wrong direction.
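Here is a one-dimensional sketch that makes the double counting concrete. Its assumptions:

- `grid_starts` and `coverage_counts` are hypothetical helpers.
- The patch-count formula is a simplification, not GridGeoSampler's actual code.
- With `clamp_last=True`, the last patch is pulled back inside the image, as in #630.
- With `clamp_last=False`, the stride stays fixed and the last patch extends past the edge, where the reader would fill nodata.

```python
import math


def grid_starts(length, size, stride, clamp_last=True):
    """Start offsets of patches of width `size` covering [0, length) along one axis."""
    if length <= size:
        return [0]
    # Enough patches that the last one reaches (or passes) the end of the axis
    n = math.ceil((length - size) / stride) + 1
    starts = [i * stride for i in range(n)]
    if clamp_last:
        # #630-style behaviour: pull the last patch back inside the image
        starts[-1] = length - size
    return starts


def coverage_counts(length, size, starts):
    """How many patches cover each in-bounds pixel along the axis."""
    counts = [0] * length
    for start in starts:
        for x in range(start, min(start + size, length)):
            counts[x] += 1
    return counts


clamped = grid_starts(250, 100, 100, clamp_last=True)
fixed = grid_starts(250, 100, 100, clamp_last=False)
print(clamped, fixed)  # [0, 100, 150] [0, 100, 200]
print(sum(c > 1 for c in coverage_counts(250, 100, clamped)))  # 50
print(sum(c > 1 for c in coverage_counts(250, 100, fixed)))    # 0
```

Take a 250-pixel-wide axis with size and stride 100. Clamping gives starts [0, 100, 150] and counts pixels 150 to 199 twice. The fixed stride gives [0, 100, 200] and covers every in-bounds pixel exactly once; the last 50 pixels of the third patch fall outside the image.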
Technically this issue also occurs when multiple images in the dataset intersect, but this is harder to mitigate without storing all predictions in one giant tensor and computing performance only on the final predicted mask. I think we would run out of memory very quickly.
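A minimal NumPy sketch of the "one giant tensor" approach follows. It sums per-class scores from each patch into a full-size canvas along with a per-pixel count and averages the overlaps; metrics could then be computed once on the final mask. Assumptions:

- `stitch` and its arguments are hypothetical.
- Origins are (row, col) pixel offsets within the canvas.
- Patches that run past the edge are cropped.

Memory is the obstacle. The float32 score canvas alone needs num_classes × H × W × 4 bytes: for 18 classes over a hypothetical 50,000 × 50,000 pixel mosaic, that is 18 × 50,000 × 50,000 × 4 bytes ≈ 180 GB.

```python
import numpy as np


def stitch(patches, origins, height, width, num_classes):
    """Average overlapping per-class patch scores onto one (height, width) canvas.

    patches: iterable of (num_classes, h, w) score arrays.
    origins: matching iterable of non-negative (row, col) offsets.
    Returns the argmax label mask and the per-pixel patch count.
    """
    scores = np.zeros((num_classes, height, width), dtype=np.float32)
    counts = np.zeros((height, width), dtype=np.int32)
    for patch, (row, col) in zip(patches, origins):
        # Crop any part of the patch that falls outside the canvas
        h = min(patch.shape[1], height - row)
        w = min(patch.shape[2], width - col)
        if h <= 0 or w <= 0:
            continue  # patch lies entirely outside the canvas
        scores[:, row:row + h, col:col + w] += patch[:, :h, :w]
        counts[row:row + h, col:col + w] += 1
    # Average overlaps; pixels no patch covered keep zero scores (and label 0)
    mean = scores / np.maximum(counts, 1)
    return mean.argmax(axis=0), counts


# Two 2x2 patches on a 2x3 canvas that overlap in the middle column
a = np.stack([np.zeros((2, 2)), np.ones((2, 2))])        # favours class 1
b = np.stack([np.full((2, 2), 3.0), np.zeros((2, 2))])   # favours class 0
labels, counts = stitch([a, b], [(0, 0), (0, 1)], height=2, width=3, num_classes=2)
print(labels.tolist())  # [[1, 0, 0], [1, 0, 0]]
print(counts.tolist())  # [[1, 2, 1], [1, 2, 1]]
```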
ignore_index weighting
This one may also affect training with other GeoSamplers, although I'm most concerned about evaluation.
When sampling from large tiles, many patches will contain partial or complete nodata pixels. TorchMetrics allows us to ignore these areas using `ignore_index`. However, it's unclear to me whether all patches are weighted equally when Lightning computes the final performance metrics. Ideally, the overall reported accuracy would be the same whether we chip the image into small patches or compute accuracy on the entire image/mask in one go.

We could peruse the internals of TorchMetrics and Lightning, but I think it's actually easier to construct a toy example to determine whether or not this issue occurs. Consider an image with width 200 and height 100. Let the first 99 columns of the ground truth mask be 0, the 100th column be 1, and the last 100 columns be 2. Let the predicted mask be a tensor of all 1s. If we use a GridGeoSampler with size 100 and stride 100, and let `ignore_index=0`, the correct performance should be ~1%: 100 correct pixels out of 100 + 10,000 = 10,100 non-ignored pixels. If the actual reported performance is 50%, the mean of 100% on the first patch and 0% on the second, we'll know we have an issue.
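The toy example can be reproduced without Lightning or TorchMetrics to see what two candidate aggregation strategies would report. The sketch uses NumPy and a hypothetical `pixel_counts` helper.

- "Pooled" sums correct and valid pixel counts over patches before dividing. This is what we'd expect from a stateful metric that accumulates counts across batches.
- "Per-patch mean" averages each patch's accuracy. This is what we'd expect if per-batch values are averaged.

Which of the two Lightning and TorchMetrics actually produce is the open question here.

```python
import numpy as np

IGNORE_INDEX = 0

# Ground truth (height 100, width 200): 99 columns of 0, one column of 1, 100 columns of 2
target = np.full((100, 200), 2, dtype=np.int64)
target[:, :99] = 0
target[:, 99] = 1
pred = np.ones_like(target)  # predicted mask of all 1s


def pixel_counts(pred, target, ignore_index=IGNORE_INDEX):
    """Return (correct, valid) pixel counts, skipping ignore_index pixels."""
    valid = target != ignore_index
    return int((pred[valid] == target[valid]).sum()), int(valid.sum())


# GridGeoSampler-like chips with size 100 and stride 100: columns 0-99 and 100-199
patches = [(pred[:, c:c + 100], target[:, c:c + 100]) for c in (0, 100)]
patch_counts = [pixel_counts(p, t) for p, t in patches]

correct, valid = pixel_counts(pred, target)
whole_image = correct / valid
pooled = sum(c for c, _ in patch_counts) / sum(v for _, v in patch_counts)
per_patch_mean = float(np.mean([c / v for c, v in patch_counts]))

print(f"{whole_image:.4f} {pooled:.4f} {per_patch_mean:.4f}")  # 0.0099 0.0099 0.5000
```

A reported value near 1% means the patches are effectively pooled. A value of 50% means each patch is weighted equally, regardless of how many non-ignored pixels it contains.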