@@ -1,122 +1,27 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-This is a demo script showing how to use the
-PrithviGeospatialMAE model with vLLM
-This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
-
-Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
-
-The requirements for running this script are:
-- Installing [terratorch, albumentations, rasterio] in your python environment
-- downloading the model weights in a 'model' folder local to the script
-  (temporary measure until the proper config.json file is uploaded to HF)
-- download an input example image (India_900498_S2Hand.tif) and place it in
-  the same folder with the script (or specify with the --data_file argument)
-
-Run the example:
-python prithvi_geospatial_mae.py
-
-"""  # noqa: E501
-
 import argparse
 import datetime
 import os
+import re
 from typing import Union
 
 import albumentations
 import numpy as np
 import rasterio
-import regex as re
 import torch
 from einops import rearrange
 from terratorch.datamodules import Sen1Floods11NonGeoDataModule
 
 from vllm import LLM
 
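+# Make host-side tensor construction default to float16 so that inputs built
+# in this script match the dtype="float16" vLLM engine configured below.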
+torch.set_default_dtype(torch.float16)
+
 NO_DATA = -9999
 NO_DATA_FLOAT = 0.0001
 OFFSET = 0
 PERCENTILE = 99
 
-model_config = """{
-    "architectures": ["PrithviGeoSpatialMAE"],
-    "num_classes": 0,
-    "pretrained_cfg": {
-        "task_args": {
-            "task": "SemanticSegmentationTask",
-            "model_factory": "EncoderDecoderFactory",
-            "loss": "ce",
-            "ignore_index": -1,
-            "lr": 0.001,
-            "freeze_backbone": false,
-            "freeze_decoder": false,
-            "plot_on_val": 10,
-            "optimizer": "AdamW",
-            "scheduler": "CosineAnnealingLR"
-        },
-        "model_args": {
-            "backbone_pretrained": false,
-            "backbone": "prithvi_eo_v2_300_tl",
-            "decoder": "UperNetDecoder",
-            "decoder_channels": 256,
-            "decoder_scale_modules": true,
-            "num_classes": 2,
-            "rescale": true,
-            "backbone_bands": [
-                "BLUE",
-                "GREEN",
-                "RED",
-                "NIR_NARROW",
-                "SWIR_1",
-                "SWIR_2"
-            ],
-            "head_dropout": 0.1,
-            "necks": [
-                {
-                    "name": "SelectIndices",
-                    "indices": [
-                        5,
-                        11,
-                        17,
-                        23
-                    ]
-                },
-                {
-                    "name": "ReshapeTokensToImage"
-                }
-            ]
-        },
-        "optimizer_params": {
-            "lr": 5.0e-05,
-            "betas": [0.9, 0.999],
-            "eps": [1.0e-08],
-            "weight_decay": 0.05,
-            "amsgrad": false,
-            "maximize": false,
-            "capturable": false,
-            "differentiable": false
-        },
-        "scheduler_params": {
-            "T_max": 50,
-            "eta_min": 0,
-            "last_epoch": -1,
-            "verbose": "deprecated"
-        }
-    },
-
-
-    "torch_dtype": "float32"
-}
-"""
-
-# Temporarily creating the "config.json" for the model.
-# This is going to disappear once the correct config.json is available on HF
-with open(
-    os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
-) as config_file:
-    config_file.write(model_config)
-
 datamodule_config = {
     "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
     "batch_size": 16,
@@ -138,28 +43,24 @@
 
 
 class PrithviMAE:
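+    # Wrapper that drives the PrithviGeoSpatialMAE checkpoint through vLLM's
+    # pooling ("encode") path; no tokenizer or text generation is involved.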
-    def __init__(self):
-        print("Initializing PrithviMAE model")
-        self.llm = LLM(
-            model=os.path.join(os.path.dirname(__file__), "./model"),
-            skip_tokenizer_init=True,
-            dtype="float32",
+    def __init__(self, model):
+        self.model = LLM(
+            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
         )
 
     def run(self, input_data, location_coords):
-        print("################ Running inference on vLLM ##############")
         # merge the inputs into one data structure
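+        # The engine runs at float16, so downcast float32 pixels to match and
+        # drop the leading batch axis: each request carries a single sample.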
+        if input_data is not None and input_data.dtype == torch.float32:
+            input_data = input_data.to(torch.float16)
+        input_data = input_data[0]
+
         mm_data = {
-            "pixel_values": torch.empty(0) if input_data is None else input_data,
-            "location_coords": torch.empty(0)
-            if location_coords is None
-            else location_coords,
+            "pixel_values": input_data,
+            "location_coords": location_coords,
         }
 
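+        # With skip_tokenizer_init=True there is no text to tokenize: a single
+        # placeholder token id stands in for the prompt, and the real inputs
+        # travel through multi_modal_data.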
         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-
-        outputs = self.llm.encode(prompt, use_tqdm=False)
-        print("################ Inference done (it took seconds) ##############")
+        outputs = self.model.encode(prompt, use_tqdm=False)
 
         return outputs[0].outputs.data
 
@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
18182 """
18283 Args:
18384 orig_img: torch.Tensor representing original image (reference)
184- with shape = (bands, H, W).
85+ with shape = (bands, H, W).
18586 channels: list of indices representing RGB channels.
18687
18788 Returns:
188- torch.Tensor with shape (num_channels, height, width) for original image
89+ torch.Tensor with shape (num_channels, height, width)
90+ for original image
18991 """
19092
19193 orig_img = orig_img [channels , ...]
@@ -260,10 +162,10 @@ def load_example(
 
     Args:
         file_paths: list of file paths.
-        mean: list containing mean values for each band in the images
-        in *file_paths*.
-        std: list containing std values for each band in the images
-        in *file_paths*.
+        mean: list containing mean values for each band in the
+            images in *file_paths*.
+        std: list containing std values for each band in the
+            images in *file_paths*.
 
     Returns:
         np.array containing created example
@@ -308,7 +210,7 @@ def load_example(
             print(f"Could not extract timestamp for {file} ({e})")
 
     imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
-    imgs = np.moveaxis(imgs, -1, 0).astype("float32")
+    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
     imgs = np.expand_dims(imgs, axis=0)  # add batch dim
 
     return imgs, temporal_coords, location_coords, metas
@@ -332,8 +234,10 @@ def run_model(
     )
 
     # Build sliding window
+
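+    # unfold() tiles the scene into non-overlapping img_size x img_size
+    # windows along H and W; each tile is predicted independently and the
+    # outputs are reassembled after inference.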
     batch_size = 1
-    batch = torch.tensor(input_data, device="cpu")
+    # batch = torch.tensor(input_data, device="cpu")
+    batch = torch.tensor(input_data)
     windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
     h1, w1 = windows.shape[3:5]
     windows = rearrange(
@@ -344,34 +248,24 @@ def run_model(
     num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
     windows = torch.tensor_split(windows, num_batches, dim=0)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
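+    # Inputs stay on the CPU from here on; vLLM moves request data to the
+    # model's device internally, so the explicit .to(device) calls are gone.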
     if temporal_coords:
-        temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
+        temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
     else:
         temporal_coords = None
     if location_coords:
-        location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
+        location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
     else:
         location_coords = None
 
-    # Run model
+    # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
     pred_imgs = []
     for x in windows:
         # Apply standardization
         x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
         x = datamodule.aug(x)["image"]
 
         with torch.no_grad():
-            x = x.to(device)
             pred = model.run(x, location_coords=location_coords)
-            if lightning_model:
-                pred_lightning = lightning_model(
-                    x, temporal_coords=temporal_coords, location_coords=location_coords
-                )
-                pred_lightning = pred_lightning.output.detach().cpu()
-                if not torch.equal(pred, pred_lightning):
-                    print("Inference output is not equal")
         y_hat = pred.argmax(dim=1)
 
         y_hat = torch.nn.functional.interpolate(
@@ -403,52 +297,18 @@ def run_model(
     return pred_imgs
 
 
-def parse_args():
-    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
-
-    parser.add_argument(
-        "--data_file",
-        type=str,
-        default="./India_900498_S2Hand.tif",
-        help="Path to the file.",
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output",
-        help="Path to the directory where to save outputs.",
-    )
-    parser.add_argument(
-        "--input_indices",
-        default=[1, 2, 3, 8, 11, 12],
-        type=int,
-        nargs="+",
-        help="0-based indices of the six Prithvi channels to be selected from the "
-        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
-    )
-    parser.add_argument(
-        "--rgb_outputs",
-        action="store_true",
-        help="If present, output files will only contain RGB channels. "
-        "Otherwise, all bands will be saved.",
-    )
-
-
 def main(
     data_file: str,
+    model: str,
     output_dir: str,
     rgb_outputs: bool,
     input_indices: list[int] = None,
 ):
     os.makedirs(output_dir, exist_ok=True)
 
-    # Load model ---------------------------------------------------------------
-
-    model_obj = PrithviMAE()
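+    # vLLM resolves --model itself (local path or Hugging Face repo id), so
+    # the temporary local config.json workaround is no longer needed.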
+    model_obj = PrithviMAE(model=model)
     datamodule = generate_datamodule()
-    img_size = 256  # Size of Sen1Floods11
-
-    # Loading data -------------------------------------------------------------
+    img_size = 512  # Size of Sen1Floods11
 
     input_data, temporal_coords, location_coords, meta_data = load_example(
         file_paths=[data_file],
@@ -460,16 +320,13 @@ def main(
     if input_data.mean() > 1:
         input_data = input_data / 10000  # Convert to range 0-1
 
-    # Running model ------------------------------------------------------------
-
     channels = [
         datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
     ]  # BGR -> RGB
 
     pred = run_model(
         input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
     )
-
     # Save pred
     meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
     pred_file = os.path.join(
@@ -487,6 +344,7 @@ def main(
         orig_img=torch.Tensor(input_data[0, :, 0, ...]),
         channels=channels,
     )
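+    # Tensors created above default to float16 (see set_default_dtype); cast
+    # back to float32 before blending with the NaN-masked prediction.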
+    rgb_orig = rgb_orig.to(torch.float32)
 
     pred[pred == 0.0] = np.nan
     img_pred = rgb_orig * 0.7 + pred * 0.3
@@ -503,9 +361,10 @@ def main(
 
     # Save image rgb
     if rgb_outputs:
+        name_suffix = os.path.splitext(os.path.basename(data_file))[0]
         rgb_file = os.path.join(
             output_dir,
-            f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
+            f"original_rgb_{name_suffix}.tiff",
         )
         save_geotiff(
             image=_convert_np_uint8(rgb_orig),
@@ -515,6 +374,42 @@ def main(
 
 
 if __name__ == "__main__":
-    args = parse_args()
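+    # Example invocation (all flags optional, defaults shown in --help):
+    #   python prithvi_geospatial_mae.py --data_file ./India_900498_S2Hand.tif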
+    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
+
+    parser.add_argument(
+        "--data_file",
+        type=str,
+        default="./India_900498_S2Hand.tif",
+        help="Path to the file.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
+        help="Path or Hugging Face repo id of the model checkpoint to load.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output",
+        help="Path to the directory where to save outputs.",
+    )
+    parser.add_argument(
+        "--input_indices",
+        default=[1, 2, 3, 8, 11, 12],
+        type=int,
+        nargs="+",
+        help="""
+        0-based indices of the six Prithvi channels to be selected from the input.
+        By default selects [1,2,3,8,11,12] for S2L1C data.
+        """,
+    )
+    parser.add_argument(
+        "--rgb_outputs",
+        action="store_true",
+        help="If present, output files will only contain RGB channels. "
+        "Otherwise, all bands will be saved.",
+    )
+    args = parser.parse_args()
 
     main(**vars(args))