Skip to content

Commit

Permalink
fix #234: Added method to fixup non json-spec compliant floats to mak…
Browse files Browse the repository at this point in the history
…e the resp… (#236)

* Added method to fixup non json-spec compliant floats to make the response json compliant
  • Loading branch information
dotgc authored and Jasper Schulz committed Aug 13, 2019
1 parent 3ad0216 commit ffe17d5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
28 changes: 27 additions & 1 deletion src/gluonts/shell/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

# Third-party imports
import requests
import numpy as np
from flask import Flask, Response, jsonify, request
from gunicorn.app.base import BaseApplication
from pydantic import BaseModel, BaseSettings
Expand Down Expand Up @@ -82,6 +83,31 @@ def log_throughput(instances, timings):
# list(zip(timings, item_lengths)


def jsonify_floats(json_object):
"""
Traverses through the JSON object and converts non JSON-spec compliant
floats(nan, -inf, inf) to their string representations.
Parameters
----------
json_object
JSON object
"""
if isinstance(json_object, dict):
return {k: jsonify_floats(v) for k, v in json_object.items()}
elif isinstance(json_object, list):
return [jsonify_floats(item) for item in json_object]
elif isinstance(json_object, float):
if np.isnan(json_object):
return "NaN"
elif np.isposinf(json_object):
return "Infinity"
elif np.isneginf(json_object):
return "-Infinity"
return json_object
return json_object


class Settings(BaseSettings):
# see: https://pydantic-docs.helpmanual.io/#settings
class Config:
Expand Down Expand Up @@ -191,7 +217,7 @@ def invocations() -> Response:

log_throughput(req.instances, forecasts.timings)

return jsonify(predictions=predictions)
return jsonify(predictions=jsonify_floats(predictions))

return app

Expand Down
23 changes: 22 additions & 1 deletion test/shell/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# permissions and limitations under the License.

# Standard library imports
import json
from typing import ContextManager

# Third-party imports
Expand All @@ -23,7 +24,7 @@
from gluonts.model.trivial.mean import MeanPredictor
from gluonts.shell import testutil
from gluonts.shell.sagemaker import ServeEnv, TrainEnv
from gluonts.shell.serve import Settings
from gluonts.shell.serve import Settings, jsonify_floats
from gluonts.shell.train import run_train_and_test

context_length = 5
Expand Down Expand Up @@ -161,3 +162,23 @@ def test_dynamic_shell(
assert exp_samples_shape == act_samples.shape
assert equals(exp_mean, act_mean)
assert equals(exp_samples, act_samples)


def test_as_json_dict_outputs_valid_json():
non_compliant_json = {
"a": float("nan"),
"k": float("infinity"),
"b": {
"c": float("nan"),
"d": "testing",
"e": float("-infinity"),
"f": float("infinity"),
"g": {"h": float("nan")},
},
}

with pytest.raises(ValueError):
json.dumps(non_compliant_json, allow_nan=False)

output_json = jsonify_floats(non_compliant_json)
json.dumps(output_json, allow_nan=False)

0 comments on commit ffe17d5

Please sign in to comment.