diff --git a/CHANGES.md b/CHANGES.md index 25a70aeb6..cef59ac8e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,7 @@ * renamed `OptionalHeaders`, `MimeTypes` and `ImageDrivers` enums to the singular form. (https://github.com/developmentseed/titiler/pull/258) * renamed `MimeType` to `MediaType` (https://github.com/developmentseed/titiler/pull/258) +* add `ColorMapParams` dependency to ease the creation of custom colormap dependency (https://github.com/developmentseed/titiler/pull/252) ## 0.1.0 (2021-02-17) diff --git a/docs/concepts/dependencies.md b/docs/concepts/dependencies.md index e30bfbba5..4e6e4e4b6 100644 --- a/docs/concepts/dependencies.md +++ b/docs/concepts/dependencies.md @@ -162,15 +162,27 @@ The `factories` allow users to set multiple default dependencies. Here is the li title="Color Formula", description="rio-color formula (info: https://github.com/mapbox/rio-color)", ) - color_map: Optional[ColorMapNames] = Query( - None, description="rio-tiler's colormap name" - ) return_mask: bool = Query(True, description="Add mask to the output data.") - colormap: Optional[Dict[int, Tuple[int, int, int, int]]] = field(init=False) + + rescale_range: Optional[List[Union[float, int]]] = field(init=False) def __post_init__(self): """Post Init.""" - self.colormap = cmap.get(self.color_map.value) if self.color_map else None + self.rescale_range = ( + list(map(float, self.rescale.split(","))) if self.rescale else None + ) + ``` + +* **colormap_dependency**: colormap options. + + ```python + def ColorMapParams( + color_map: ColorMapNames = Query(None, description="Colormap name",) + ) -> Optional[Dict]: + """Colormap Dependency.""" + if color_map: + return cmap.get(color_map.value) + return None ``` * **additional_dependency**: Default dependency, will be passed as `**kwargs` to all endpoints. diff --git a/tests/test_CustomCmap.py b/tests/test_CustomCmap.py new file mode 100644 index 000000000..0e1311282 --- /dev/null +++ b/tests/test_CustomCmap.py @@ -0,0 +1,56 @@ +# """Test TiTiler Custom Colormap Params.""" + +from enum import Enum +from io import BytesIO +from typing import Dict, Optional + +import numpy +from rio_tiler.colormap import ColorMaps + +from titiler.endpoints import factory + +from .conftest import DATA_DIR + +from fastapi import FastAPI, Query + +from starlette.testclient import TestClient + +cmap_values = { + "cmap1": {6: (4, 5, 6, 255)}, +} +cmap = ColorMaps(data=cmap_values) +ColorMapNames = Enum( # type: ignore + "ColorMapNames", [(a, a) for a in sorted(cmap.list())] +) + + +def ColorMapParams( + color_map: ColorMapNames = Query(None, description="Colormap name",) +) -> Optional[Dict]: + """Colormap Dependency.""" + if color_map: + return cmap.get(color_map.value) + return None + + +def test_CustomCmap(): + """Test Custom Render Params dependency.""" + app = FastAPI() + cog = factory.TilerFactory(colormap_dependency=ColorMapParams) + app.include_router(cog.router) + client = TestClient(app) + + response = client.get( + f"/preview.npy?url={DATA_DIR}/above_cog.tif&bidx=1&color_map=cmap1" + ) + assert response.status_code == 200 + assert response.headers["content-type"] == "application/x-binary" + data = numpy.load(BytesIO(response.content)) + assert 4 in data[0] + assert 5 in data[1] + assert 6 in data[2] + + response = client.get( + f"/preview.npy?url={DATA_DIR}/above_cog.tif&bidx=1&color_map=another_cmap" + ) + assert response.status_code == 422 diff --git a/titiler/dependencies.py b/titiler/dependencies.py index a24d8cdf9..2111a26cf 100644 --- a/titiler/dependencies.py +++ b/titiler/dependencies.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import numpy from morecantile import tms @@ -68,6 +68,15 @@ def TMSParams( return tms.get(TileMatrixSetId.name) +def ColorMapParams( + color_map: ColorMapNames = Query(None, description="Colormap name",) +) -> Optional[Dict]: + """Colormap Dependency.""" + if color_map: + return cmap.get(color_map.value) + return None + + @dataclass class DefaultDependency: """Dependency Base Class""" @@ -327,17 +336,12 @@ class RenderParams(DefaultDependency): title="Color Formula", description="rio-color formula (info: https://github.com/mapbox/rio-color)", ) - color_map: Optional[ColorMapNames] = Query( - None, description="rio-tiler's colormap name" - ) return_mask: bool = Query(True, description="Add mask to the output data.") - colormap: Optional[Dict[int, Tuple[int, int, int, int]]] = field(init=False) rescale_range: Optional[List[Union[float, int]]] = field(init=False) def __post_init__(self): """Post Init.""" - self.colormap = cmap.get(self.color_map.value) if self.color_map else None self.rescale_range = ( list(map(float, self.rescale.split(","))) if self.rescale else None ) diff --git a/titiler/endpoints/factory.py b/titiler/endpoints/factory.py index 77d9e74b5..881d4a64d 100644 --- a/titiler/endpoints/factory.py +++ b/titiler/endpoints/factory.py @@ -24,6 +24,7 @@ BandsParams, BidxExprParams, BidxParams, + ColorMapParams, DatasetParams, DefaultDependency, ImageParams, @@ -86,6 +87,8 @@ class BaseTilerFactory(metaclass=abc.ABCMeta): # Image rendering Dependencies render_dependency: Type[DefaultDependency] = RenderParams + colormap_dependency: Callable[..., Optional[Dict]] = ColorMapParams + # TileMatrixSet dependency tms_dependency: Callable[..., TileMatrixSet] = WebMercatorTMSParams @@ -298,6 +301,7 @@ def tile( layer_params=Depends(self.layer_dependency), dataset_params=Depends(self.dataset_dependency), render_params=Depends(self.render_dependency), + colormap=Depends(self.colormap_dependency), kwargs: Dict = Depends(self.additional_dependency), ): """Create map tile from a dataset.""" @@ -338,7 +342,7 @@ def tile( content = image.render( add_mask=render_params.return_mask, img_format=format.driver, - colormap=render_params.colormap or dst_colormap, + colormap=colormap or dst_colormap, **format.profile, **render_params.kwargs, ) @@ -385,6 +389,7 @@ def tilejson( layer_params=Depends(self.layer_dependency), # noqa dataset_params=Depends(self.dataset_dependency), # noqa render_params=Depends(self.render_dependency), # noqa + colormap=Depends(self.colormap_dependency), # noqa kwargs: Dict = Depends(self.additional_dependency), # noqa ): """Return TileJSON document for a dataset.""" @@ -452,6 +457,7 @@ def wmts( layer_params=Depends(self.layer_dependency), # noqa dataset_params=Depends(self.dataset_dependency), # noqa render_params=Depends(self.render_dependency), # noqa + colormap=Depends(self.colormap_dependency), # noqa kwargs: Dict = Depends(self.additional_dependency), # noqa ): """OGC WMTS endpoint.""" @@ -572,6 +578,7 @@ def preview( img_params=Depends(self.img_dependency), dataset_params=Depends(self.dataset_dependency), render_params=Depends(self.render_dependency), + colormap=Depends(self.colormap_dependency), kwargs: Dict = Depends(self.additional_dependency), ): """Create preview of a dataset.""" @@ -587,9 +594,7 @@ def preview( **dataset_params.kwargs, **kwargs, ) - colormap = render_params.colormap or getattr( - src_dst, "colormap", None - ) + colormap = colormap or getattr(src_dst, "colormap", None) timings.append(("dataread", round(t.elapsed * 1000, 2))) if not format: @@ -643,6 +648,7 @@ def part( image_params=Depends(self.img_dependency), dataset_params=Depends(self.dataset_dependency), render_params=Depends(self.render_dependency), + colormap=Depends(self.colormap_dependency), kwargs: Dict = Depends(self.additional_dependency), ): """Create image from part of a dataset.""" @@ -659,9 +665,7 @@ def part( **dataset_params.kwargs, **kwargs, ) - colormap = render_params.colormap or getattr( - src_dst, "colormap", None - ) + colormap = colormap or getattr(src_dst, "colormap", None) timings.append(("dataread", round(t.elapsed * 1000, 2))) with utils.Timer() as t: @@ -1082,6 +1086,7 @@ def tile( layer_params=Depends(self.layer_dependency), dataset_params=Depends(self.dataset_dependency), render_params=Depends(self.render_dependency), + colormap=Depends(self.colormap_dependency), pixel_selection: PixelSelectionMethod = Query( PixelSelectionMethod.first, description="Pixel selection method." ), @@ -1132,7 +1137,7 @@ def tile( content = image.render( add_mask=render_params.return_mask, img_format=format.driver, - colormap=render_params.colormap, + colormap=colormap, **format.profile, **render_params.kwargs, ) @@ -1182,6 +1187,7 @@ def tilejson( layer_params=Depends(self.layer_dependency), # noqa dataset_params=Depends(self.dataset_dependency), # noqa render_params=Depends(self.render_dependency), # noqa + colormap=Depends(self.colormap_dependency), # noqa pixel_selection: PixelSelectionMethod = Query( PixelSelectionMethod.first, description="Pixel selection method." ), # noqa @@ -1247,6 +1253,7 @@ def wmts( layer_params=Depends(self.layer_dependency), # noqa dataset_params=Depends(self.dataset_dependency), # noqa render_params=Depends(self.render_dependency), # noqa + colormap=Depends(self.colormap_dependency), # noqa pixel_selection: PixelSelectionMethod = Query( PixelSelectionMethod.first, description="Pixel selection method." ), # noqa