From f9b3b879b4413c2235681d8ea6f4bdc9a12bf332 Mon Sep 17 00:00:00 2001 From: Kyle Wilcox Date: Wed, 8 Feb 2023 11:33:22 -0500 Subject: [PATCH] Allow NaN attributes in datasets (#152) * Allow NaN attributes in datasets Datasets with NaN value attributes currently raise a `ValueError: Out of range float values are not JSON compliant` error. This adds a custom starlette JsonResponse class that explicitly allows nan values. * Allow NaN values in all plugins/routes * Fix pre-commit errors --------- Co-authored-by: Joe Hamman --- tests/conftest.py | 2 ++ tests/test_fsspec_compat.py | 2 +- tests/test_rest_api.py | 6 +++++- xpublish/plugins/included/dataset_info.py | 8 +++++--- xpublish/plugins/included/zarr.py | 11 ++++++----- xpublish/utils/api.py | 16 +++++++++++++++- 6 files changed, 34 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c652abae..18cb3c69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,8 @@ def airtemp_ds(): ds = xr.tutorial.open_dataset('air_temperature') ds['air'].encoding['_FillValue'] = -9999 + ds['air'].attrs['nan_attribute'] = np.nan + ds['air'].attrs['none_attribute'] = None return ds diff --git a/tests/test_fsspec_compat.py b/tests/test_fsspec_compat.py index 1af65058..b733c2ca 100644 --- a/tests/test_fsspec_compat.py +++ b/tests/test_fsspec_compat.py @@ -12,7 +12,7 @@ def test_get_zmetadata_key(airtemp_ds): mapper = TestMapper(SingleDatasetRest(airtemp_ds).app) actual = json.loads(mapper['.zmetadata'].decode()) expected = jsonify_zmetadata(airtemp_ds, create_zmetadata(airtemp_ds)) - assert actual == expected + assert json.dumps(actual, allow_nan=True) == json.dumps(expected, allow_nan=True) def test_missing_key_raises_keyerror(airtemp_ds): diff --git a/tests/test_rest_api.py b/tests/test_rest_api.py index b776212a..dfbce26d 100644 --- a/tests/test_rest_api.py +++ b/tests/test_rest_api.py @@ -1,3 +1,5 @@ +import json + import pytest import uvicorn import xarray as xr @@ -288,7 +290,9 @@ def test_repr(airtemp_ds, airtemp_app_client): def test_zmetadata(airtemp_ds, airtemp_app_client): response = airtemp_app_client.get('/.zmetadata') assert response.status_code == 200 - assert response.json() == jsonify_zmetadata(airtemp_ds, create_zmetadata(airtemp_ds)) + assert json.dumps(response.json()) == json.dumps( + jsonify_zmetadata(airtemp_ds, create_zmetadata(airtemp_ds)) + ) def test_bad_key(airtemp_app_client): diff --git a/xpublish/plugins/included/dataset_info.py b/xpublish/plugins/included/dataset_info.py index 1b4d3fbf..49b39073 100644 --- a/xpublish/plugins/included/dataset_info.py +++ b/xpublish/plugins/included/dataset_info.py @@ -5,6 +5,8 @@ from starlette.responses import HTMLResponse # type: ignore from zarr.storage import attrs_key # type: ignore +from xpublish.utils.api import JSONResponse + from ...dependencies import get_zmetadata, get_zvariables from .. import Dependencies, Plugin, hookimpl @@ -32,13 +34,13 @@ def html_representation( def list_keys( dataset=Depends(deps.dataset), ): - return list(dataset.variables) + return JSONResponse(list(dataset.variables)) @router.get('/dict') def to_dict( dataset=Depends(deps.dataset), ): - return dataset.to_dict(data=False) + return JSONResponse(dataset.to_dict(data=False)) @router.get('/info') def info( @@ -67,6 +69,6 @@ def info( info['global_attributes'] = meta[attrs_key] - return info + return JSONResponse(info) return router diff --git a/xpublish/plugins/included/zarr.py b/xpublish/plugins/included/zarr.py index 9a8491c0..8e4ee243 100644 --- a/xpublish/plugins/included/zarr.py +++ b/xpublish/plugins/included/zarr.py @@ -1,4 +1,3 @@ -import json import logging from typing import Sequence @@ -8,6 +7,8 @@ from starlette.responses import Response # type: ignore from zarr.storage import array_meta_key, attrs_key, group_meta_key # type: ignore +from xpublish.utils.api import JSONResponse + from ...dependencies import get_zmetadata, get_zvariables from ...utils.api import DATASET_ID_ATTR_KEY from ...utils.cache import CostTimer @@ -39,7 +40,7 @@ def get_zarr_metadata( zjson = jsonify_zmetadata(dataset, zmetadata) - return Response(json.dumps(zjson).encode('ascii'), media_type='application/json') + return JSONResponse(zjson) @router.get(f'/{group_meta_key}') def get_zarr_group( @@ -49,7 +50,7 @@ def get_zarr_group( zvariables = get_zvariables(dataset, cache) zmetadata = get_zmetadata(dataset, cache, zvariables) - return zmetadata['metadata'][group_meta_key] + return JSONResponse(zmetadata['metadata'][group_meta_key]) @router.get(f'/{attrs_key}') def get_zarr_attrs( @@ -59,7 +60,7 @@ def get_zarr_attrs( zvariables = get_zvariables(dataset, cache) zmetadata = get_zmetadata(dataset, cache, zvariables) - return zmetadata['metadata'][attrs_key] + return JSONResponse(zmetadata['metadata'][attrs_key]) @router.get('/{var}/{chunk}') def get_variable_chunk( @@ -80,7 +81,7 @@ def get_variable_chunk( if array_meta_key in chunk: return zmetadata['metadata'][f'{var}/{array_meta_key}'] elif attrs_key in chunk: - return zmetadata['metadata'][f'{var}/{attrs_key}'] + return JSONResponse(zmetadata['metadata'][f'{var}/{attrs_key}']) elif group_meta_key in chunk: raise HTTPException(status_code=404, detail='No subgroups') else: diff --git a/xpublish/utils/api.py b/xpublish/utils/api.py index d6c23ff5..470b1fa3 100644 --- a/xpublish/utils/api.py +++ b/xpublish/utils/api.py @@ -1,9 +1,11 @@ +import json from collections.abc import Mapping -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple import xarray as xr from fastapi import APIRouter from fastapi.openapi.utils import get_openapi +from starlette.responses import JSONResponse as StarletteJSONResponse # type: ignore DATASET_ID_ATTR_KEY = '_xpublish_id' @@ -118,3 +120,15 @@ def openapi(self): self._app.openapi_schema = openapi_schema return self._app.openapi_schema + + +class JSONResponse(StarletteJSONResponse): + def __init__(self, *args, **kwargs): + self._render_kwargs = dict( + ensure_ascii=True, allow_nan=True, indent=None, separators=(',', ':') + ) + self._render_kwargs.update(kwargs.pop('render_kwargs', {})) + super().__init__(*args, **kwargs) + + def render(self, content: Any) -> bytes: + return json.dumps(content, **self._render_kwargs).encode('utf-8')