Skip to content

Commit

Permalink
atdpy: evaluate default field values for each object creation instead of
Browse files Browse the repository at this point in the history
sharing the same physical value across all objects of the same class.
Fixes #339
  • Loading branch information
mjambon committed May 10, 2023
1 parent 12ea5c5 commit 8bd56fb
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 24 deletions.
28 changes: 14 additions & 14 deletions atdpy/src/lib/Codegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ methods and functions to convert data from/to JSON.

# Import annotations to allow forward references
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union

import json
Expand Down Expand Up @@ -544,19 +544,16 @@ let rec type_name_of_expr env (e : type_expr) : string =
| Name (loc, (_, name, _::_), _) -> assert false
| Tvar (loc, _) -> not_implemented loc "type variables"

let rec get_default_default
?(mutable_ok = true) (e : type_expr) : string option =
let rec get_default_default (e : type_expr) : string option =
match e with
| Sum _
| Record _
| Tuple _ (* a default tuple could be possible but we're lazy *) -> None
| List _ ->
if mutable_ok then Some "[]"
else None
| List _ -> Some "[]"
| Option _
| Nullable _ -> Some "None"
| Shared (loc, e, an) -> get_default_default ~mutable_ok e
| Wrap (loc, e, an) -> get_default_default ~mutable_ok e
| Shared (loc, e, an) -> get_default_default e
| Wrap (loc, e, an) -> get_default_default e
| Name (loc, (loc2, name, []), an) ->
(match name with
| "unit" -> Some "None"
Expand All @@ -570,12 +567,11 @@ let rec get_default_default
| Name _ -> None
| Tvar _ -> None

let get_python_default
?mutable_ok (e : type_expr) (an : annot) : string option =
let get_python_default (e : type_expr) (an : annot) : string option =
let user_default = Python_annot.get_python_default an in
match user_default with
| Some s -> Some s
| None -> get_default_default ?mutable_ok e
| None -> get_default_default e

(* see explanation where this function is used *)
let has_no_class_inst_prop_default
Expand All @@ -584,7 +580,7 @@ let has_no_class_inst_prop_default
| Required -> true
| Optional -> (* default is None *) false
| With_default ->
match get_python_default ~mutable_ok:false e an with
match get_python_default e an with
| Some _ -> false
| None ->
(* There's either no default at all which is an error,
Expand Down Expand Up @@ -795,9 +791,13 @@ let inst_var_declaration
| Required -> ""
| Optional -> " = None"
| With_default ->
match get_python_default ~mutable_ok:false unwrapped_e an with
match get_python_default unwrapped_e an with
| None -> ""
| Some value -> sprintf " = %s" value
| Some x ->
(* This constructs ensures that a fresh default value is
evaluated for each class instanciation. It's important for
default lists since Python lists are mutable. *)
sprintf " = field(default_factory=lambda: %s)" x
in
[
Line (sprintf "%s: %s%s" var_name type_name default)
Expand Down
4 changes: 4 additions & 0 deletions atdpy/test/atd-input/everything.atd
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ type recursive_class = {
flag: bool;
children: recursive_class list;
}

type default_list = {
~items: int list;
}
40 changes: 34 additions & 6 deletions atdpy/test/python-expected/everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Import annotations to allow forward references
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Union

import json
Expand Down Expand Up @@ -455,7 +455,7 @@ class IntFloatParametrizedRecord:
"""Original type: _int_float_parametrized_record = { ... }"""

field_a: int
field_b: List[float]
field_b: List[float] = field(default_factory=lambda: [])

@classmethod
def from_json(cls, x: Any) -> 'IntFloatParametrizedRecord':
Expand Down Expand Up @@ -489,7 +489,6 @@ class Root:
await_: bool
x___init__: float
items: List[List[int]]
extras: List[int]
aliased: Alias
point: Tuple[float, float]
kinds: List[Kind]
Expand All @@ -503,7 +502,8 @@ class Root:
parametrized_record: IntFloatParametrizedRecord
parametrized_tuple: KindParametrizedTuple
maybe: Optional[int] = None
answer: int = 42
extras: List[int] = field(default_factory=lambda: [])
answer: int = field(default_factory=lambda: 42)

@classmethod
def from_json(cls, x: Any) -> 'Root':
Expand All @@ -513,7 +513,6 @@ def from_json(cls, x: Any) -> 'Root':
await_=_atd_read_bool(x['await']) if 'await' in x else _atd_missing_json_field('Root', 'await'),
x___init__=_atd_read_float(x['__init__']) if '__init__' in x else _atd_missing_json_field('Root', '__init__'),
items=_atd_read_list(_atd_read_list(_atd_read_int))(x['items']) if 'items' in x else _atd_missing_json_field('Root', 'items'),
extras=_atd_read_list(_atd_read_int)(x['extras']) if 'extras' in x else [],
aliased=Alias.from_json(x['aliased']) if 'aliased' in x else _atd_missing_json_field('Root', 'aliased'),
point=(lambda x: (_atd_read_float(x[0]), _atd_read_float(x[1])) if isinstance(x, list) and len(x) == 2 else _atd_bad_json('array of length 2', x))(x['point']) if 'point' in x else _atd_missing_json_field('Root', 'point'),
kinds=_atd_read_list(Kind.from_json)(x['kinds']) if 'kinds' in x else _atd_missing_json_field('Root', 'kinds'),
Expand All @@ -527,6 +526,7 @@ def from_json(cls, x: Any) -> 'Root':
parametrized_record=IntFloatParametrizedRecord.from_json(x['parametrized_record']) if 'parametrized_record' in x else _atd_missing_json_field('Root', 'parametrized_record'),
parametrized_tuple=KindParametrizedTuple.from_json(x['parametrized_tuple']) if 'parametrized_tuple' in x else _atd_missing_json_field('Root', 'parametrized_tuple'),
maybe=_atd_read_int(x['maybe']) if 'maybe' in x else None,
extras=_atd_read_list(_atd_read_int)(x['extras']) if 'extras' in x else [],
answer=_atd_read_int(x['answer']) if 'answer' in x else 42,
)
else:
Expand All @@ -538,7 +538,6 @@ def to_json(self) -> Any:
res['await'] = _atd_write_bool(self.await_)
res['__init__'] = _atd_write_float(self.x___init__)
res['items'] = _atd_write_list(_atd_write_list(_atd_write_int))(self.items)
res['extras'] = _atd_write_list(_atd_write_int)(self.extras)
res['aliased'] = (lambda x: x.to_json())(self.aliased)
res['point'] = (lambda x: [_atd_write_float(x[0]), _atd_write_float(x[1])] if isinstance(x, tuple) and len(x) == 2 else _atd_bad_python('tuple of length 2', x))(self.point)
res['kinds'] = _atd_write_list((lambda x: x.to_json()))(self.kinds)
Expand All @@ -553,6 +552,7 @@ def to_json(self) -> Any:
res['parametrized_tuple'] = (lambda x: x.to_json())(self.parametrized_tuple)
if self.maybe is not None:
res['maybe'] = _atd_write_int(self.maybe)
res['extras'] = _atd_write_list(_atd_write_int)(self.extras)
res['answer'] = _atd_write_int(self.answer)
return res

Expand Down Expand Up @@ -683,3 +683,31 @@ def from_json_string(cls, x: str) -> 'Frozen':

def to_json_string(self, **kw: Any) -> str:
return json.dumps(self.to_json(), **kw)


@dataclass
class DefaultList:
"""Original type: default_list = { ... }"""

items: List[int] = field(default_factory=lambda: [])

@classmethod
def from_json(cls, x: Any) -> 'DefaultList':
if isinstance(x, dict):
return cls(
items=_atd_read_list(_atd_read_int)(x['items']) if 'items' in x else [],
)
else:
_atd_bad_json('DefaultList', x)

def to_json(self) -> Any:
res: Dict[str, Any] = {}
res['items'] = _atd_write_list(_atd_write_int)(self.items)
return res

@classmethod
def from_json_string(cls, x: str) -> 'DefaultList':
return cls.from_json(json.loads(x))

def to_json_string(self, **kw: Any) -> str:
return json.dumps(self.to_json(), **kw)
21 changes: 17 additions & 4 deletions atdpy/test/python-tests/test_atdpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ def test_everything_to_json() -> None:
2
]
],
"extras": [
17,
53
],
"aliased": [
8,
9,
Expand Down Expand Up @@ -200,6 +196,10 @@ def test_everything_to_json() -> None:
"wow",
100
],
"extras": [
17,
53
],
"answer": 42
}"""
b_obj = e.Root.from_json_string(a_str)
Expand Down Expand Up @@ -256,5 +256,18 @@ def test_recursive_class() -> None:
assert b_str2 == a_str


def test_default_list() -> None:
a = e.DefaultList(items=[])
assert a.items == []
b = e.DefaultList()
assert b.items == []
c = e.DefaultList.from_json_string("{}")
assert c.items == []
# We could emit '{}' instead of '{"items": []}' but it's more complicated
# and not always desired.
j = b.to_json_string()
assert j == '{"items": []}'


# print updated json
test_everything_to_json()

0 comments on commit 8bd56fb

Please sign in to comment.