Skip to content

Commit

Permalink
Serde cleanup. (awslabs#441)
Browse files Browse the repository at this point in the history
* Serde cleanup.
  • Loading branch information
Jasper Schulz authored and Ayed committed Nov 29, 2019
1 parent c8ed59c commit 59cad51
Showing 1 changed file with 78 additions and 52 deletions.
130 changes: 78 additions & 52 deletions src/gluonts/core/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import re
import textwrap
from functools import singledispatch
from pathlib import Path
from pathlib import PurePath
from pydoc import locate
from typing import Any, Optional
from typing import cast, Any, NamedTuple, Optional

# Third-party imports
import mxnet as mx
Expand Down Expand Up @@ -184,42 +184,58 @@ def _dump_code(x: Any) -> str:
# r = { 'class': ..., 'args': ... }
# r = { 'class': ..., 'kwargs': ... }
if type(x) == dict and x.get("__kind__") == kind_inst:
args = x["args"] if "args" in x else []
kwargs = x["kwargs"] if "kwargs" in x else {}
return "{fqname}({bindings})".format(
fqname=x["class"],
bindings=", ".join(
itertools.chain(
[_dump_code(v) for v in args],
[f"{k}={_dump_code(v)}" for k, v in kwargs.items()],
)
),
args = x.get("args", [])
kwargs = x.get("kwargs", {})

fqname = x["class"]
bindings = ", ".join(
itertools.chain(
map(_dump_code, args),
[f"{k}={_dump_code(v)}" for k, v in kwargs.items()],
)
)
return f"{fqname}({bindings})"

if type(x) == dict and x.get("__kind__") == kind_type:
return x["class"]

if isinstance(x, dict):
elems = [f"{_dump_code(k)}: {_dump_code(v)}" for k, v in x.items()]
return "{" + ", ".join(elems) + "}"
elif isinstance(x, list):
elems = [dump_code(v) for v in x]
return "[" + ", ".join(elems) + "]"
elif isinstance(x, tuple):
elems = [dump_code(v) for v in x]
return "(" + ", ".join(elems) + ",)"
elif isinstance(x, str):
inner = ", ".join(
f"{_dump_code(k)}: {_dump_code(v)}" for k, v in x.items()
)
return f"{{{inner}}}"

if isinstance(x, list):
inner = ", ".join(list(map(dump_code, x)))
return f"[{inner}]"

if isinstance(x, tuple):
inner = ", ".join(list(map(dump_code, x)))
# account for the extra `,` in `(x,)`
if len(x) == 1:
inner += ","
return f"({inner})"

if isinstance(x, str):
# json.dumps escapes the string
return json.dumps(x)
elif isinstance(x, float) or np.issubdtype(type(x), np.inexact):
return str(x) if math.isfinite(x) else 'float("' + str(x) + '")'
elif (
isinstance(x, int)
or np.issubdtype(type(x), np.integer)
or x is None
):

if isinstance(x, float) or np.issubdtype(type(x), np.inexact):
if math.isfinite(x):
return str(x)
else:
# e.g. `nan` needs to be encoded as `float("nan")`
return 'float("{x}")'

if isinstance(x, int) or np.issubdtype(type(x), np.integer):
return str(x)
else:
x = fqname_for(x.__class__)
raise RuntimeError(f"Unexpected element type {x}")

if x is None:
return str(x)

raise RuntimeError(
f"Unexpected element type {fqname_for(x.__class__)}"
)

return _dump_code(encode(o))

Expand Down Expand Up @@ -265,7 +281,8 @@ def _load_code(code: str, modules=None):
str(e),
)
if m:
name = m["module"] + "." + m["package"]
module, package = m["module"], m["package"]
name = f"{module}.{package}"
return _load_code(
code,
{**(modules or {}), name: importlib.import_module(name)},
Expand Down Expand Up @@ -389,43 +406,52 @@ def encode(v: Any) -> Any:
"""
if isinstance(v, type(None)):
return None
elif isinstance(v, (float, int, str)):

if isinstance(v, (float, int, str)):
return v
elif np.issubdtype(type(v), np.inexact):

if np.issubdtype(type(v), np.inexact):
return float(v)
elif np.issubdtype(type(v), np.integer):

if np.issubdtype(type(v), np.integer):
return int(v)
elif isinstance(v, (list, set)) or type(v) == tuple:
return [encode(v) for v in v]
elif isinstance(v, tuple) and not hasattr(v, "_asdict"):
return tuple([encode(v) for v in v])
elif isinstance(v, dict):
return {k: encode(v) for k, v in v.items()}
elif isinstance(v, type):
return {"__kind__": kind_type, "class": fqname_for(v)}
elif isinstance(v, tuple) and hasattr(v, "_asdict"):

# we have to check for namedtuples first, to encode them not as plain
# tuples (which would become lists)
if isinstance(v, tuple) and hasattr(v, "_asdict"):
v = cast(NamedTuple, v)
return {
"__kind__": kind_inst,
"class": fqname_for(v.__class__),
"kwargs": encode(getattr(v, "_asdict")()),
"kwargs": encode(v._asdict()),
}
elif hasattr(v, "__getnewargs_ex__"):
args, kwargs = getattr(v, "__getnewargs_ex__")()

if isinstance(v, (list, set, tuple)):
return list(map(encode, v))

if isinstance(v, dict):
return {k: encode(v) for k, v in v.items()}

if isinstance(v, type):
return {"__kind__": kind_type, "class": fqname_for(v)}

if hasattr(v, "__getnewargs_ex__"):
args, kwargs = v.__getnewargs_ex__() # mypy: ignore
return {
"__kind__": kind_inst,
"class": fqname_for(v.__class__),
"args": encode(args),
"kwargs": encode(kwargs),
}
else:
raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__)))

raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__)))


@encode.register(Path)
def encode_path(v: Path) -> Any:
@encode.register(PurePath)
def encode_path(v: PurePath) -> Any:
"""
Specializes :func:`encode` for invocations where ``v`` is an instance of
the :class:`~Path` class.
the :class:`~PurePath` class.
"""
return {
"__kind__": kind_inst,
Expand Down

0 comments on commit 59cad51

Please sign in to comment.