Skip to content

Commit

Permalink
refactor: added Argument type
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Nov 24, 2024
1 parent f8f4682 commit 0f6c1d1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 28 deletions.
67 changes: 43 additions & 24 deletions src/inline_snapshot/_adapter/generic_call_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,31 @@
def get_adapter_for_type(typ):
subclasses = GenericCallAdapter.__subclasses__()
options = [cls for cls in subclasses if cls.check_type(typ)]
# print(typ,options)

if not options:
return

assert len(options) == 1
return options[0]


class Argument:
value: Any
is_default: bool = False

def __init__(self, value, is_default=False):
self.value = value
self.is_default = is_default


class GenericCallAdapter(Adapter):

@classmethod
def check_type(cls, typ) -> bool:
raise NotImplementedError(cls)

@classmethod
def arguments(cls, value) -> tuple[list[Any], dict[str, Any]]:
def arguments(cls, value) -> tuple[list[Argument], dict[str, Argument]]:
raise NotImplementedError(cls)

@classmethod
Expand All @@ -49,8 +58,10 @@ def repr(cls, value):

args, kwargs = cls.arguments(value)

arguments = [repr(value) for value in args] + [
f"{key}={repr(value)}" for key, value in kwargs.items()
arguments = [repr(value.value) for value in args] + [
f"{key}={repr(value.value)}"
for key, value in kwargs.items()
if not value.is_default
]

return f"{repr(type(value))}({', '.join(arguments)})"
Expand All @@ -59,8 +70,11 @@ def repr(cls, value):
def map(cls, value, map_function):
new_args, new_kwargs = cls.arguments(value)
return type(value)(
*[adapter_map(arg, map_function) for arg in new_args],
**{k: adapter_map(kwarg, map_function) for k, kwarg in new_kwargs.items()},
*[adapter_map(arg.value, map_function) for arg in new_args],
**{
k: adapter_map(kwarg.value, map_function)
for k, kwarg in new_kwargs.items()
},
)

def items(self, value, node):
Expand Down Expand Up @@ -114,12 +128,10 @@ def assign(self, old_value, old_node, new_value):
for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)):
old_value_element = self.argument(old_value, i)
result = yield from self.get_adapter(
old_value_element, new_value_element
).assign(old_value_element, node, new_value_element)
old_value_element, new_value_element.value
).assign(old_value_element, node, new_value_element.value)
result_args.append(result)

print(old_node.args)
print(new_args)
if len(old_node.args) > len(new_args):
for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]:
print("del", arg_pos)
Expand All @@ -138,14 +150,14 @@ def assign(self, old_value, old_node, new_value):
node=old_node,
arg_pos=insert_pos,
arg_name=None,
new_code=self.context._value_to_code(value),
new_value=value,
new_code=self.context._value_to_code(value.value),
new_value=value.value,
)

# keyword arguments
result_kwargs = {}
for kw in old_node.keywords:
if not kw.arg in new_kwargs:
if not kw.arg in new_kwargs or new_kwargs[kw.arg].is_default:
# delete entries
yield Delete(
"fix",
Expand All @@ -159,18 +171,20 @@ def assign(self, old_value, old_node, new_value):
to_insert = []
insert_pos = 0
for key, new_value_element in new_kwargs.items():
if new_value_element.is_default:
continue
if key not in old_node_kwargs:
# add new values
to_insert.append((key, new_value_element))
result_kwargs[key] = new_value_element
to_insert.append((key, new_value_element.value))
result_kwargs[key] = new_value_element.value
else:
node = old_node_kwargs[key]

# check values with same keys
old_value_element = self.argument(old_value, key)
result_kwargs[key] = yield from self.get_adapter(
old_value_element, new_value_element
).assign(old_value_element, node, new_value_element)
old_value_element, new_value_element.value
).assign(old_value_element, node, new_value_element.value)

if to_insert:
for key, value in to_insert:
Expand Down Expand Up @@ -219,17 +233,18 @@ def arguments(cls, value):
for field in fields(value): # type: ignore
if field.repr:
field_value = getattr(value, field.name)
is_default = False

if field.default != MISSING and field.default == field_value:
continue
is_default = True

if (
field.default_factory != MISSING
and field.default_factory() == field_value
):
continue
is_default = True

kwargs[field.name] = field_value
kwargs[field.name] = Argument(value=field_value, is_default=is_default)

return ([], kwargs)

Expand All @@ -256,13 +271,14 @@ def arguments(cls, value):
return (
[],
{
name: getattr(value, name)
name: Argument(value=getattr(value, name))
for name, info in value.model_fields.items()
if getattr(value, name) != info.default
},
)

def argument(self, value, pos_or_name):
@classmethod
def argument(cls, value, pos_or_name):
assert isinstance(pos_or_name, str)
return getattr(value, pos_or_name)

Expand Down Expand Up @@ -296,7 +312,7 @@ def arguments(cls, value: IsNamedTuple):
return (
[],
{
field: getattr(value, field)
field: Argument(value=getattr(value, field))
for field in value._fields
if field not in value._field_defaults
or getattr(value, field) != value._field_defaults[field]
Expand All @@ -316,7 +332,10 @@ def check_type(cls, value):
@classmethod
def arguments(cls, value: defaultdict):

return ([value.default_factory, dict(value)], {})
return (
[Argument(value=value.default_factory), Argument(value=dict(value))],
{},
)

def argument(self, value, pos_or_name):
assert isinstance(pos_or_name, int)
Expand Down
8 changes: 4 additions & 4 deletions tests/adapter/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_remove_positional_argument():
"""\
from inline_snapshot import snapshot
from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter
from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument
class L:
Expand All @@ -274,7 +274,7 @@ def check_type(cls, typ):
@classmethod
def arguments(cls, value):
return (value.l,{})
return ([Argument(x) for x in value.l],{})
@classmethod
def argument(cls, value, pos_or_name):
Expand All @@ -294,7 +294,7 @@ def test_L2():
"test_something.py": """\
from inline_snapshot import snapshot
from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter
from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument
class L:
Expand All @@ -313,7 +313,7 @@ def check_type(cls, typ):
@classmethod
def arguments(cls, value):
return (value.l,{})
return ([Argument(x) for x in value.l],{})
@classmethod
def argument(cls, value, pos_or_name):
Expand Down

0 comments on commit 0f6c1d1

Please sign in to comment.