Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace tif file writer with MDS writer in pipeline #167

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions scripts/pipeline/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
Process Sentinel-2, Sentinel-1, and DEM data for a specified time range,
area of interest, and resolution.
"""
import os
import random
from datetime import timedelta
from pathlib import Path

import click
import geopandas as gpd
Expand All @@ -38,6 +40,7 @@
import xarray as xr
from pystac import ItemCollection
from shapely.geometry import box
from streaming.base import MDSWriter
from tile import tiler

STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
Expand All @@ -48,6 +51,11 @@
NODATA = 0
S1_MATCH_ATTEMPTS = 20
DATES_PER_LOCATION = 3
VERSION = "03"
MDS_COLUMNS = {"lat": "float32", "lon": "float32", "date": "str", "pixels": "ndarray"}
MDS_COMPRESSION = "zstd:9"
MDS_HASHES = ["sha1"]
MDS_LIMIT = "100MB"


def get_surrounding_days(reference, interval_days):
Expand Down Expand Up @@ -379,6 +387,10 @@ def process(
resolution,
)

if 0 in (dat.shape[0] for dat in result):
print("Pixels coverage does not overlap although bboxes do")
return None, None

return date, result


Expand Down Expand Up @@ -464,33 +476,48 @@ def main(sample, index, subset, bucket, localpath, dateranges):
random.seed(index)
random.shuffle(date_ranges)

match_count = 0
for date_range in date_ranges:
print(f"Processing data for date range {date_range}")
date, pixels = process(
tile.geometry,
date_range,
SPATIAL_RESOLUTION,
CLOUD_COVER_PERCENTAGE,
NODATA_PIXEL_PERCENTAGE,
)
if date is None:
continue
else:
match_count += 1
if not localpath:
outpath = f"s3://{bucket}/{VERSION}/{mgrs}"
else:
outpath = str(Path(localpath) / Path(f"{VERSION}/{mgrs}"))
os.makedirs(localpath, exist_ok=True)

with MDSWriter(
out=outpath,
columns=MDS_COLUMNS,
compression=MDS_COMPRESSION,
hashes=MDS_HASHES,
size_limit=MDS_LIMIT,
) as writer:
match_count = 0
for date_range in date_ranges:
print(f"Processing data for date range {date_range}")
date, pixels = process(
tile.geometry,
date_range,
SPATIAL_RESOLUTION,
CLOUD_COVER_PERCENTAGE,
NODATA_PIXEL_PERCENTAGE,
)
if date is None:
continue
else:
match_count += 1

if subset:
print(f"Subsetting to {subset}")
pixels = [
part[:, subset[1] : subset[3], subset[0] : subset[2]] for part in pixels
]
if subset:
print(f"Subsetting to {subset}")
pixels = [
part[:, subset[1] : subset[3], subset[0] : subset[2]]
for part in pixels
]

pixels = [part.compute() for part in pixels]
pixels = [part.compute() for part in pixels]

tiler(pixels, date, mgrs, bucket, localpath)
for sample in tiler(pixels, date):
writer.write(sample)

if match_count == DATES_PER_LOCATION:
break
if match_count == DATES_PER_LOCATION:
break

if not match_count:
raise ValueError("No matching data found")
Expand Down
62 changes: 9 additions & 53 deletions scripts/pipeline/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,15 @@
It includes functions to filter tiles based on cloud coverage and no-data pixels,
and a tiling function that generates smaller tiles from the input stack.
"""
import subprocess
import tempfile

import numpy as np
import rasterio
import rioxarray # noqa: F401
import xarray as xr
from rasterio.enums import ColorInterp

NODATA = 0
TILE_SIZE = 512
PIXELS_PER_TILE = TILE_SIZE * TILE_SIZE
BAD_PIXEL_MAX_PERCENTAGE = 0.3
SCL_FILTER = [0, 1, 3, 8, 9, 10]
VERSION = "02"


def filter_clouds_nodata(tile):
Expand Down Expand Up @@ -55,19 +49,15 @@ def filter_clouds_nodata(tile):
return True # If both conditions pass


def tile_to_dir(stack, date, mgrs, bucket, dir):
def tiler(stack, date):
"""
Function to tile a multi-dimensional imagery stack while filtering out
tiles with high cloud coverage or no-data pixels.

Args:
- stack (xarray.Dataset): The input multi-dimensional imagery stack.
- date (str): Date string yyyy-mm-dd
- mgrs (str): MGRS Tile id
- bucket(str): AWS S3 bucket to write tiles to
"""
print("Writing tempfiles to ", dir)

# Calculate the number of full tiles in x and y directions
num_x_tiles = stack[0].x.size // TILE_SIZE
num_y_tiles = stack[0].y.size // TILE_SIZE
Expand Down Expand Up @@ -98,45 +88,11 @@ def tile_to_dir(stack, date, mgrs, bucket, dir):

tile = tile.drop_sel(band="SCL")

# Track band names and color interpretation
tile.attrs["long_name"] = [str(x.values) for x in tile.band]
color = [ColorInterp.blue, ColorInterp.green, ColorInterp.red] + [
ColorInterp.gray
] * (len(tile.band) - 3)

# Write tile to tempdir
name = "{dir}/claytile_{mgrs}_{date}_v{version}_{counter}.tif".format(
dir=dir,
mgrs=mgrs,
date=date.replace("-", ""),
version=VERSION,
counter=str(counter).zfill(4),
)
tile.rio.to_raster(name, compress="deflate")

with rasterio.open(name, "r+") as rst:
rst.colorinterp = color
rst.update_tags(date=date)
if bucket:
print(f"Syncing {dir} with s3://{bucket}/{VERSION}/{mgrs}/{date}")
subprocess.run(
[
"aws",
"s3",
"sync",
dir,
f"s3://{bucket}/{VERSION}/{mgrs}/{date}",
"--no-progress",
],
check=True,
)
else:
print("No bucket specified, skipping S3 sync.")


def tiler(stack, date, mgrs, bucket, dir):
if dir:
tile_to_dir(stack, date, mgrs, bucket, dir)
else:
with tempfile.TemporaryDirectory() as tmpdir:
tile_to_dir(stack, date, mgrs, bucket, tmpdir)
bounds = tile.rio.transform_bounds("EPSG:4326")

yield {
"pixels": tile.to_numpy(),
"date": date,
"lat": bounds[1] + (bounds[1] - bounds[3]) / 2,
"lon": bounds[0] + (bounds[0] - bounds[2]) / 2,
}