-
Notifications
You must be signed in to change notification settings - Fork 5
/
detect_sargassum.py
241 lines (165 loc) · 7.4 KB
/
detect_sargassum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# Computes a sargassum index (e.g. AFAI or ASI) from an L2A Sentinel-2 dataset
# =====================================================
import datetime
import glob
import os
import sys
import numpy as np
import rasterio
import Sentinel2
from AFAI import AFAI_Index
from ASI import ASI_Index
# ==============================================================================
def _parse_dataset_name(dataset_path):
    """Return (sensing_date, tile, satellite) parsed from a .SAFE product name.

    Fields are sliced at fixed offsets of the Sentinel-2 product name, e.g.
    "S2A_MSIL2A_20190101T150719_N0211_R082_T18QYF_20190101T181228.SAFE"
    gives datetime(2019, 1, 1, 15, 7, 19), "18QYF" and "2A".

    Raises:
        ValueError: if the sensing date cannot be parsed from the name.
    """
    basename = os.path.basename(os.path.normpath(dataset_path))
    try:
        date = datetime.datetime.strptime(basename[11:26], "%Y%m%dT%H%M%S")
    except ValueError as err:
        message = (
            "Cannot parse the sensing date from dataset name {!r}; expected a "
            "Sentinel-2 product name such as "
            "S2A_MSIL2A_20190101T150719_N0211_R082_T18QYF_20190101T181228.SAFE"
        ).format(basename)
        raise ValueError(message) from err
    tile = basename[39:44]      # tile code without its leading "T"
    satellite = basename[1:3]   # e.g. "2A"
    return date, tile, satellite


def _output_path(out_dir, tile, date, index_name, extension):
    """Return <out_dir>/<tile>_<YYYYMMDD>_<index_name>.<extension>."""
    fname = "{}_{}_{}.{}".format(tile, date.strftime("%Y%m%d"), index_name, extension)
    return os.path.join(out_dir, fname)


def _write_raster(out_path, result, img_meta, driver):
    """Write `result` as band 1 of a new single-band raster file.

    The rasterio profile is a copy of `img_meta` (the metadata of the loaded
    channels, which carries the georeferencing) with driver, dtype and band
    count overridden, so the caller's metadata dict is never mutated.
    """
    profile = dict(img_meta, driver=driver, dtype=result.dtype, count=1)
    with rasterio.open(out_path, "w", **profile) as fout:
        fout.write(result, 1)


def detect_sargassum(dataset_path, sarg_index, compute_kwargs=None, apply_mask=True, mask_keep_categs=(6,), masked_value=np.nan, threshold=None, resolution="20", save_npy=False, save_geotiff=False, save_jp2=False, out_dir=None, verbose=True):
    """Computes a sargassum index from a Sentinel-2 dataset.

    Parameters:
        dataset_path : string
            The path to the .SAFE directory containing an L2A Sentinel-2
            dataset. The sensing date, tile and satellite are parsed from the
            directory name.
        sarg_index : an instance of a class derived from Sargassum_Index
            The sargassum index to use for the detection. Its `name`,
            `required_channels` and `compute()` are used.

    Options:
        compute_kwargs : dict (default: None, i.e. no extra arguments)
            Additional keyword arguments passed to sarg_index.compute().
        apply_mask : boolean (default: True)
            Whether to apply the Sentinel-2 SCL mask to the data, marking
            masked pixels as invalid. The SCL layer is only loaded when True.
        mask_keep_categs : list or tuple (default: (6,))
            The SCL categories to consider valid if apply_mask is True.
            Defaults to only WATER.
        masked_value : float (default: np.nan)
            The value used for masked pixels in the output image. Ignored
            when threshold is given (masked pixels are then set to 2).
        threshold : numeric (default: None)
            If given, the index image is thresholded: values >= threshold
            become 1, all others (including NaN) become 0, and the result
            is uint8. Masked pixels are set to 2 in this case.
        resolution : string (default: "20")
            The Sentinel-2 spatial resolution to use (20 m by default). All
            channels required by sarg_index must be available at it.
        save_npy : boolean (default: False)
            Whether to save the output image as a numpy .npy file.
        save_geotiff : boolean (default: False)
            Whether to save the output image in GeoTIFF format, with the
            georeferencing metadata of the loaded channels.
        save_jp2 : boolean (default: False)
            Whether to save the output image in OpenJPEG2000 format, with the
            georeferencing metadata of the loaded channels. Only uint8 images
            (i.e. thresholded output) can be saved; otherwise a warning is
            printed and the file is skipped.
        out_dir : string (default: None)
            The directory where output files are written. Defaults to
            dataset_path itself. Files are named
            <tile>_<YYYYMMDD>_<index name>.<npy|tif|jp2>.
        verbose : boolean (default: True)
            Whether to report actions to the screen.

    Returns:
        result : numpy array
            The image returned by sarg_index.compute(), thresholded and/or
            masked as requested.

    Raises:
        ValueError: if the sensing date cannot be parsed from dataset_path.
    """
    # Avoid shared mutable defaults: a fresh dict per call.
    if compute_kwargs is None:
        compute_kwargs = {}
    # ----------------------------------------------------------------------
    # Parse tile and date from dataset name
    if verbose:
        print("\nComputing sargassum index {} for Sentinel-2 dataset:\n{}".format(sarg_index.name, dataset_path))
    date, tile, satellite = _parse_dataset_name(dataset_path)
    if verbose:
        print("UTM/MGRS Tile: {}".format(tile))
        print("Sensing Date: {} UTC".format(date.strftime("%Y-%m-%d %H:%M:%S")))
        print("Satellite: Sentinel-{}".format(satellite))
    # --------------------------------------------------
    # Load required channels at requested resolution.
    # NOTE(review): the path returned here was never used; the call is kept in
    # case locate_data_path validates that the resolution exists - confirm.
    Sentinel2.locate_data_path(dataset_path, resolution)
    if verbose:
        print("\nLoading channels {} ...".format(sarg_index.required_channels))
    dataset = Sentinel2.load_channels(dataset_path, sarg_index.required_channels, resolution, verbose=verbose)
    img_meta = dataset["meta"]
    # --------------------------------------------------
    # Load SCL only if masking was requested
    if apply_mask:
        if verbose:
            print("\nLoading SCL mask ...")
        SCL = Sentinel2.load_SCL(dataset_path, resolution)
        SCL_mask = np.isin(SCL, mask_keep_categs)  # True where the pixel is kept
        if verbose:
            mask_categs, counts = np.unique(SCL, return_counts=True)
            # .tolist() gives plain Python scalars for a readable printout
            print("Mask counts:", dict(zip(mask_categs.tolist(), counts.tolist())))
            mask_keep_count = np.count_nonzero(SCL_mask)
            print("{:,} ({:.1f}%) pixels are unmasked".format(
                mask_keep_count, 100 * mask_keep_count / max(SCL.size, 1)))
    # --------------------------------------------------
    # Compute index
    if verbose:
        print("\nComputing {} ...".format(sarg_index.name))
    result = sarg_index.compute(dataset["channels"], **compute_kwargs)
    # --------------------------------------------------
    # Apply threshold, if requested
    if threshold is not None:
        if verbose:
            print("\nApplying threshold of {} ...".format(threshold))
        result = np.where(result >= threshold, 1, 0).astype("uint8")
    # --------------------------------------------------
    # Apply mask, if requested. A thresholded image is uint8 and cannot hold
    # NaN, so masked pixels get the distinct value 2 there instead.
    if apply_mask:
        if verbose:
            print("\nApplying mask ...")
        result[~SCL_mask] = 2 if threshold is not None else masked_value
    # --------------------------------------------------
    # Save results to disk, if requested
    if save_jp2 and result.dtype != np.uint8:
        print("\nWarning: can't save {} image as JPEG2000 (only uint8 is supported); skipping".format(result.dtype))
        save_jp2 = False
    if not (save_npy or save_geotiff or save_jp2):
        return result
    index_name = sarg_index.name.replace(" ", "_")
    if out_dir is None:
        out_dir = dataset_path
    if save_npy:
        out_path = _output_path(out_dir, tile, date, index_name, "npy")
        np.save(out_path, result)
        if verbose:
            print("\nWrote {}".format(out_path))
    if save_geotiff:
        out_path = _output_path(out_dir, tile, date, index_name, "tif")
        _write_raster(out_path, result, img_meta, "GTiff")
        if verbose:
            print("\nWrote {}".format(out_path))
    if save_jp2:
        out_path = _output_path(out_dir, tile, date, index_name, "jp2")
        _write_raster(out_path, result, img_meta, "JP2OpenJPEG")
        if verbose:
            print("\nWrote {}".format(out_path))
    # --------------------------------------------------
    return result
# =====================================================
# Example: compute ASI on the dataset passed as command-line arg
if __name__ == "__main__":
    # `sys` is imported at the top of the module.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <path/to/Sentinel-2_L2A_product.SAFE>".format(sys.argv[0]))
    dataset_path = sys.argv[1]
    # To use the AFAI index instead: sarg_index = AFAI_Index()
    sarg_index = ASI_Index(model_path="ASImodelColabv2.h5")
    detect_sargassum(dataset_path, sarg_index, out_dir="./", apply_mask=False, mask_keep_categs=[6], save_npy=True, save_geotiff=True, save_jp2=False)