Inference for other areas #20

Open
VaasuDevanS opened this issue Oct 23, 2024 · 0 comments

VaasuDevanS commented Oct 23, 2024

Thanks for sharing this work, @speed8928. However, I couldn't get meaningful predictions for another area using the provided pre-trained model. There are two issues: 1. the output is not representative of the input image, and 2. the output min and max values are -0.0128 and 0.304.

Below is the inference script I made. I am using Python 3.11.10 and torch 2.5.0.
Note: The only change I made in the codebase is updating line 31 in models/net.py to x = x_block0.view(-1, 440, 440)

sample.tif from step 2 is here: sample.zip (368KB)

import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch

from src.models import modules, net, senet


def define_model(model_file):
    original_model = senet.senet154(pretrained='imagenet')
    encoder = modules.E_senet(original_model)
    model = net.Model(encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    state_dict = torch.load(model_file, weights_only=True)['state_dict']
    model.load_state_dict(state_dict, strict=False)
    return model.to(device='cpu')


# 1. Load the model
dsm_model = define_model('Block0_skip_model_110.pth.tar')

# 2. Read the sample image and perform pre-processing
with rio.open('sample.tif') as src:
    image = src.read().astype('float32') / 255.0  # image.shape -> (3, 440, 440)

# 3. Get the prediction and perform post-processing
image_tensor = torch.as_tensor(np.expand_dims(image, axis=0)).to('cpu').float()
pred = dsm_model(image_tensor)
pred = torch.nn.functional.interpolate(pred, size=(440, 440), mode='bilinear')
output = pred.squeeze(axis=(0, 1)).detach().numpy()

# 4. Plot the image
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.imshow(np.transpose(image, (1, 2, 0)))
im = ax1.imshow(output)
plt.colorbar(im, ax=ax1, orientation='horizontal')
plt.savefig('plot.png', bbox_inches='tight')
plt.close()
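
One thing that may be worth ruling out, since the script loads the checkpoint with strict=False: in that mode, parameter names that do not match are skipped silently, so the model can end up running with mostly untrained weights. A common cause is a "module." prefix on checkpoint keys, which torch.nn.DataParallel adds when a wrapped model is saved. Whether this checkpoint has that prefix is an assumption, not something confirmed here. Below is a minimal sketch of a key comparison; the helper name compare_state_dict_keys and the example key names are hypothetical.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys, prefix="module."):
    """Report how parameter names line up between a model and a checkpoint.

    Hypothetical helper: it only compares names, it does not load anything.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    # Remove the prefix (e.g. the one added by torch.nn.DataParallel) if present.
    stripped = {
        key[len(prefix):] if key.startswith(prefix) else key
        for key in checkpoint_keys
    }
    return {
        "matched_as_is": sorted(model_keys & checkpoint_keys),
        "matched_after_strip": sorted(model_keys & stripped),
        "missing_in_checkpoint": sorted(model_keys - stripped),
        "unexpected_in_checkpoint": sorted(stripped - model_keys),
    }


# Made-up key names, for illustration only.
model_keys = ["E.conv1.weight", "D.up1.weight", "R.conv2.bias"]
checkpoint_keys = [
    "module.E.conv1.weight",
    "module.D.up1.weight",
    "module.R.conv2.bias",
]
report = compare_state_dict_keys(model_keys, checkpoint_keys)
print(report)
```

With the real objects, the inputs would be model.state_dict().keys() and state_dict.keys() from define_model. As a shortcut, load_state_dict itself returns an object with missing_keys and unexpected_keys fields, which can be printed to see what strict=False skipped.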

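[Figure: plot.png, with the input image on the left and the model output with a horizontal colorbar on the right]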
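
Another possible factor is input scaling. Step 2 only divides by 255, while encoders pretrained on ImageNet are usually fed images that are also normalized per channel with the ImageNet mean and standard deviation. Whether this repository's training pipeline applied that normalization is an assumption on my part and should be checked against its data loader. A minimal sketch of the normalization, with normalize_chw as a hypothetical helper name:

```python
import numpy as np

# Standard ImageNet channel statistics (RGB order), for inputs scaled to [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_chw(image, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize a (3, H, W) float image in [0, 1] channel by channel."""
    image = np.asarray(image, dtype=np.float32)
    # Reshape the statistics to (3, 1, 1) so they broadcast over H and W.
    return (image - mean[:, None, None]) / std[:, None, None]


# Dummy mid-gray image with the same shape as sample.tif after reading.
dummy = np.full((3, 440, 440), 0.5, dtype=np.float32)
normalized = normalize_chw(dummy)
print(normalized.shape, normalized[:, 0, 0])
```

In the script above this would be applied to image right after the division by 255 and before building image_tensor.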

VaasuDevanS changed the title from "Inference for another area" to "Inference for other areas" on Nov 12, 2024