From 59cad516b2f322099eb23f30161d83da92800d31 Mon Sep 17 00:00:00 2001 From: Jasper Schulz Date: Mon, 11 Nov 2019 19:16:15 +0100 Subject: [PATCH] Serde cleanup. (#441) * Serde cleanup. --- src/gluonts/core/serde.py | 130 +++++++++++++++++++++++--------------- 1 file changed, 78 insertions(+), 52 deletions(-) diff --git a/src/gluonts/core/serde.py b/src/gluonts/core/serde.py index 639fd2e197..eccfbaff21 100644 --- a/src/gluonts/core/serde.py +++ b/src/gluonts/core/serde.py @@ -20,9 +20,9 @@ import re import textwrap from functools import singledispatch -from pathlib import Path +from pathlib import PurePath from pydoc import locate -from typing import Any, Optional +from typing import cast, Any, NamedTuple, Optional # Third-party imports import mxnet as mx @@ -184,42 +184,58 @@ def _dump_code(x: Any) -> str: # r = { 'class': ..., 'args': ... } # r = { 'class': ..., 'kwargs': ... } if type(x) == dict and x.get("__kind__") == kind_inst: - args = x["args"] if "args" in x else [] - kwargs = x["kwargs"] if "kwargs" in x else {} - return "{fqname}({bindings})".format( - fqname=x["class"], - bindings=", ".join( - itertools.chain( - [_dump_code(v) for v in args], - [f"{k}={_dump_code(v)}" for k, v in kwargs.items()], - ) - ), + args = x.get("args", []) + kwargs = x.get("kwargs", {}) + + fqname = x["class"] + bindings = ", ".join( + itertools.chain( + map(_dump_code, args), + [f"{k}={_dump_code(v)}" for k, v in kwargs.items()], + ) ) + return f"{fqname}({bindings})" + if type(x) == dict and x.get("__kind__") == kind_type: return x["class"] + if isinstance(x, dict): - elems = [f"{_dump_code(k)}: {_dump_code(v)}" for k, v in x.items()] - return "{" + ", ".join(elems) + "}" - elif isinstance(x, list): - elems = [dump_code(v) for v in x] - return "[" + ", ".join(elems) + "]" - elif isinstance(x, tuple): - elems = [dump_code(v) for v in x] - return "(" + ", ".join(elems) + ",)" - elif isinstance(x, str): + inner = ", ".join( + f"{_dump_code(k)}: {_dump_code(v)}" for k, v in x.items() + ) + return f"{{{inner}}}" + + if isinstance(x, list): + inner = ", ".join(list(map(dump_code, x))) + return f"[{inner}]" + + if isinstance(x, tuple): + inner = ", ".join(list(map(dump_code, x))) + # account for the extra `,` in `(x,)` + if len(x) == 1: + inner += "," + return f"({inner})" + + if isinstance(x, str): # json.dumps escapes the string return json.dumps(x) - elif isinstance(x, float) or np.issubdtype(type(x), np.inexact): - return str(x) if math.isfinite(x) else 'float("' + str(x) + '")' - elif ( - isinstance(x, int) - or np.issubdtype(type(x), np.integer) - or x is None - ): + + if isinstance(x, float) or np.issubdtype(type(x), np.inexact): + if math.isfinite(x): + return str(x) + else: + # e.g. `nan` needs to be encoded as `float("nan")` + return 'float("{x}")' + + if isinstance(x, int) or np.issubdtype(type(x), np.integer): return str(x) - else: - x = fqname_for(x.__class__) - raise RuntimeError(f"Unexpected element type {x}") + + if x is None: + return str(x) + + raise RuntimeError( + f"Unexpected element type {fqname_for(x.__class__)}" + ) return _dump_code(encode(o)) @@ -265,7 +281,8 @@ def _load_code(code: str, modules=None): str(e), ) if m: - name = m["module"] + "." + m["package"] + module, package = m["module"], m["package"] + name = f"{module}.{package}" return _load_code( code, {**(modules or {}), name: importlib.import_module(name)}, @@ -389,43 +406,52 @@ def encode(v: Any) -> Any: """ if isinstance(v, type(None)): return None - elif isinstance(v, (float, int, str)): + + if isinstance(v, (float, int, str)): return v - elif np.issubdtype(type(v), np.inexact): + + if np.issubdtype(type(v), np.inexact): return float(v) - elif np.issubdtype(type(v), np.integer): + + if np.issubdtype(type(v), np.integer): return int(v) - elif isinstance(v, (list, set)) or type(v) == tuple: - return [encode(v) for v in v] - elif isinstance(v, tuple) and not hasattr(v, "_asdict"): - return tuple([encode(v) for v in v]) - elif isinstance(v, dict): - return {k: encode(v) for k, v in v.items()} - elif isinstance(v, type): - return {"__kind__": kind_type, "class": fqname_for(v)} - elif isinstance(v, tuple) and hasattr(v, "_asdict"): + + # we have to check for namedtuples first, to encode them not as plain + # tuples (which would become lists) + if isinstance(v, tuple) and hasattr(v, "_asdict"): + v = cast(NamedTuple, v) return { "__kind__": kind_inst, "class": fqname_for(v.__class__), - "kwargs": encode(getattr(v, "_asdict")()), + "kwargs": encode(v._asdict()), } - elif hasattr(v, "__getnewargs_ex__"): - args, kwargs = getattr(v, "__getnewargs_ex__")() + + if isinstance(v, (list, set, tuple)): + return list(map(encode, v)) + + if isinstance(v, dict): + return {k: encode(v) for k, v in v.items()} + + if isinstance(v, type): + return {"__kind__": kind_type, "class": fqname_for(v)} + + if hasattr(v, "__getnewargs_ex__"): + args, kwargs = v.__getnewargs_ex__() # mypy: ignore return { "__kind__": kind_inst, "class": fqname_for(v.__class__), "args": encode(args), "kwargs": encode(kwargs), } - else: - raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__))) + + raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__))) -@encode.register(Path) -def encode_path(v: Path) -> Any: +@encode.register(PurePath) +def encode_path(v: PurePath) -> Any: """ Specializes :func:`encode` for invocations where ``v`` is an instance of - the :class:`~Path` class. + the :class:`~PurePath` class. """ return { "__kind__": kind_inst,