diff --git a/pyproject.toml b/pyproject.toml index 8095126..e10b196 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "typegraph3" -version = "0.1.2rc1" +version = "0.1.2rc2" description = "Type Auto Switch" authors = [ {name = "luxuncang", email = "luxuncang@qq.com"}, diff --git a/src/typegraph/converter/base.py b/src/typegraph/converter/base.py index ed66ba2..b65cf5f 100644 --- a/src/typegraph/converter/base.py +++ b/src/typegraph/converter/base.py @@ -10,20 +10,24 @@ Set, Iterable, Iterator, + Sequence, + MutableSequence, cast, Type, Awaitable, Any, Optional, + Mapping, + MutableMapping, get_type_hints, - Generic, + is_typeddict, ) from functools import wraps, reduce import networkx as nx from typing_extensions import get_args, get_origin from typing_inspect import is_union_type, get_generic_type -from typeguard import check_type, TypeCheckError, CollectionCheckStrategy +from typeguard import check_type, CollectionCheckStrategy from .typevar import iter_deep_type, gen_typevar_model, extract_typevar_mapping from ..type_utils import ( @@ -32,6 +36,8 @@ is_protocol_type, check_protocol_type, get_subclass_types, + get_connected_subgraph, + iter_type_args, ) @@ -50,6 +56,7 @@ def __init__(self): self.pG = nx.DiGraph() self.tG = nx.DiGraph() self.pmG = nx.DiGraph() + self.qG = nx.DiGraph() TypeConverter.instances.append(self) def get_graph( @@ -96,8 +103,6 @@ def _gen_edge( metadata={"protocol": True}, ) for p_type in self.get_protocol_types(out_type): - if out_type == str: - print(p_type, in_type, out_type, converter) self.pG.add_edge( out_type, p_type, @@ -110,6 +115,11 @@ def _gen_graph(self, in_type: Type[In], out_type: Type[Out]): tmp_G = nx.DiGraph() im = gen_typevar_model(in_type) om = gen_typevar_model(out_type) + + def _gen_sub_graph(mapping, node): + for su, sv, sc in get_connected_subgraph(self.tG, node).edges(data=True): + tmp_G.add_edge(su.get_instance(mapping), sv.get_instance(mapping), **sc) + for u, v, c in self.tG.edges(data=True): um = gen_typevar_model(u) vm = gen_typevar_model(v) @@ -120,8 +130,27 @@ def _gen_graph(self, in_type: Type[In], out_type: Type[Out]): tmp_G.add_edge( um.get_instance(mapping), vm.get_instance(mapping), **c ) + _gen_sub_graph(mapping, t) except Exception: ... + + for su, sv, sc in self.G.edges(data=True): + su_m = gen_typevar_model(su) + for arg in iter_type_args(su_m): + try: + mapping = extract_typevar_mapping(um, arg) + sub_m = su_m.replace_args( + arg, + gen_typevar_model(vm.get_instance(mapping)), # type: ignore + ).get_instance() + tmp_G.add_edge( + sub_m, + sv, + **sc, + ) + except Exception: + ... + self.qG = nx.compose(self.qG, tmp_G) return tmp_G def register_generic_converter(self, input_type: Type, out_type: Type): @@ -333,6 +362,8 @@ def __iter_func_dict(item): res = map(_iter_func, input_value) elif in_origin == dict or out_origin == Dict: res = dict(map(__iter_func_dict, input_value.items())) + elif out_origin in (Mapping, MutableMapping): + res = dict(map(__iter_func_dict, input_value.items())) else: raise ValueError( f"Unsupported structural_type {input_type} to {out_type}" @@ -412,6 +443,11 @@ async def __iter_func_dict(item): *map(__iter_func_dict, input_value.items()) ) res = dict(items) + elif out_origin in (Mapping, MutableMapping): + items = await asyncio.gather( + *map(__iter_func_dict, input_value.items()) + ) + res = dict(items) else: raise ValueError( f"Unsupported structural_type {input_type} to {out_type}" @@ -497,17 +533,38 @@ def get_edges(self, sub_class: bool = False, protocol: bool = False): ): yield edge - def show_mermaid_graph(self, sub_class: bool = False, protocol: bool = False): + def show_mermaid_graph( + self, sub_class: bool = False, protocol: bool = False, full: bool = False + ): from IPython.display import display, Markdown + import typing + + nodes = [] + + def get_name(cls): + if type(cls) in (typing._GenericAlias, typing.GenericAlias): # type: ignore + return str(cls) + elif hasattr(cls, "__name__"): + return cls.__name__ + return str(cls) + + def get_node_name(cls): + return f"node{nodes.index(cls)}" text = "```mermaid\ngraph TD;\n" - for edge in self.get_edges(sub_class=sub_class, protocol=protocol): - line_style = "--" if edge[2]["line"] else "-.-" - text += f"{edge[0].__name__}{line_style}>{edge[1].__name__}\n" + for edge in self.get_graph( + sub_class=sub_class, protocol=protocol, combos=[self.qG] if full else None + ).edges(data=True): + if edge[0] not in nodes: + nodes.append(edge[0]) + if edge[1] not in nodes: + nodes.append(edge[1]) + line_style = "--" if edge[2].get("line", False) else "-.-" + text += f'{get_node_name(edge[0])}["{get_name(edge[0])}"] {line_style}> {get_node_name(edge[1])}["{get_name(edge[1])}"]\n' text += "```" display(Markdown(text)) - return text + # return text def get_all_paths( self, @@ -566,13 +623,21 @@ def get_check_types_by_value( for edge in G.edges(): if edge[0] in nodes: continue - nodes.add(edge[0]) - try: - check_type( - input_value, - edge[0], - collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS, - ) - except Exception: - continue - yield edge[0] + if get_origin(edge[0]) in ( + Iterable, + Iterator, + Mapping, + MutableMapping, + Sequence, + MutableSequence, + ) or is_typeddict(edge[0]): + nodes.add(edge[0]) + try: + check_type( + input_value, + edge[0], + collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS, + ) + except Exception: + continue + yield edge[0] diff --git a/src/typegraph/converter/typevar.py b/src/typegraph/converter/typevar.py index deb9163..c444925 100644 --- a/src/typegraph/converter/typevar.py +++ b/src/typegraph/converter/typevar.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import ( TypeVar, List, @@ -12,7 +13,7 @@ from typing_inspect import is_typevar from ..type_utils import ( - get_origin as get_real_origin, + get_real_origin, generate_type, ) @@ -54,6 +55,33 @@ def get_instance(self, instance: Optional[dict[Type[TypeVar], Type]] = None): raise ValueError("Invalid TypeVarModel") return generate_type(generic, args_list) + def replace_args(self, source: TypeVarModel, target: TypeVarModel) -> 'TypeVarModel': + if self.args is None: + return TypeVarModel(self.origin) + + new_args = [] + for arg in self.args: + if isinstance(arg, TypeVarModel): + if arg == source: + new_args.append(target) + else: + new_args.append(arg.replace_args(source, target)) + elif isinstance(arg, list): + new_list = [] + for item in arg: + if isinstance(item, TypeVarModel): + if item == source: + new_list.append(target) + else: + new_list.append(item.replace_args(source, target)) + else: + new_list.append(item) + new_args.append(new_list) + else: + new_args.append(arg) + + return TypeVarModel(self.origin, new_args) + def depth_first_traversal(self, parent=None, parent_arg_index=None, depth=1): if self.args: for i, arg in enumerate(self.args): diff --git a/src/typegraph/type_utils.py b/src/typegraph/type_utils.py index ed22971..8e30ab8 100644 --- a/src/typegraph/type_utils.py +++ b/src/typegraph/type_utils.py @@ -1,19 +1,13 @@ import types import typing +import networkx as nx -from typing import ( - Union, - List, - Callable, - Any, - runtime_checkable, - Type, -) +from typing import Union, List, Callable, Any, runtime_checkable, Type, get_args from typing_extensions import get_type_hints from typing_inspect import get_generic_type -def get_origin(tp): +def get_real_origin(tp): """Get the unsubscripted version of a type. This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar @@ -37,7 +31,7 @@ def get_origin(tp): if isinstance( tp, ( - typing._BaseGenericAlias, # type: ignore + typing._BaseGenericAlias, # type: ignore typing.GenericAlias, # type: ignore typing.ParamSpecArgs, typing.ParamSpecKwargs, @@ -52,7 +46,7 @@ def get_origin(tp): def is_structural_type(tp): - if get_origin(tp): + if get_real_origin(tp): return True return False @@ -155,4 +149,40 @@ def get_subclass_types(cls: Type): if hasattr(cls, "__subclasses__"): for subclass in cls.__subclasses__(): yield subclass - yield from get_subclass_types(subclass) \ No newline at end of file + yield from get_subclass_types(subclass) + + +def get_connected_nodes(graph, node): + if node not in graph: + return set() + + # 获取正向连通的节点 + successors = set(nx.descendants(graph, node)) + + # 获取逆向连通的节点 + predecessors = set(nx.ancestors(graph, node)) + + # 合并所有连通的节点 + connected_nodes = successors | predecessors | {node} + + return connected_nodes + + +def get_connected_subgraph(graph, node): + connected_nodes = get_connected_nodes(graph, node) + subgraph = graph.subgraph(connected_nodes).copy() + return subgraph + + +def iter_type_args(tp): + args = tp.args + if args: + for arg in args: + if isinstance(arg, list): + for i in arg: + yield i + yield from iter_type_args(i) + else: + yield arg + yield from iter_type_args(arg) + diff --git a/tests/test_switch.py b/tests/test_switch.py index 1d084a0..07925c5 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -2,7 +2,7 @@ import unittest import asyncio -from typing import TypeVar +from typing import TypeVar, Generic from typegraph.converter.base import TypeConverter @@ -229,10 +229,6 @@ def __init__(self, t): class Test2(int): ... - class Test3: ... - - class Test4: ... - @t.register_converter(str, int) def str_to_int(input_value): return int(input_value) @@ -263,8 +259,8 @@ def Test_to_float(input_value): def str_to_float(input_value): return float(input_value) - @self.converter.async_register_converter(dict[K,V], dict[V,K]) - async def reverse_dict(d: dict[K,V]) -> dict[V,K]: + @self.converter.async_register_converter(dict[K, V], dict[V, K]) + async def reverse_dict(d: dict[K, V]) -> dict[V, K]: return {v: k for k, v in d.items()} @self.converter.auto_convert(localns=locals()) @@ -291,10 +287,6 @@ def test_structural(x: list[str]): def test_next_structural(x: list[dict[str, int]]): return x - @self.converter.auto_convert() - def test_structural_dict(x: list[dict[str, int]]): - return x - result = test_float_to_str("10") self.assertEqual(result, "10") @@ -317,8 +309,7 @@ def test_structural_dict(x: list[dict[str, int]]): self.assertEqual(result, ["1", "2", "3"]) result = test_next_structural([{1: "1"}, {2: "2"}, {3: "3"}]) - self.assertEqual(result, [{'1': 1}, {'2': 2}, {'3': 3}]) - + self.assertEqual(result, [{"1": 1}, {"2": 2}, {"3": 3}]) def test_auto_convert_protocol(self): from typing import Protocol, TypedDict @@ -391,6 +382,35 @@ def tests(a: list[str]): result = tests([d]) self.assertEqual(result, ["John 123 123"]) + def test_auto_convert_generic(self): + t = self.converter + + class A: ... + + class B(Generic[K, V]): ... + + @t.register_converter(list[dict[str, int]], A) + def convert_list_dict_to_a(data: list[dict[str, int]]): + return A() + + @t.register_generic_converter(dict[K, V], dict[V, K]) # type: ignore + def convert_dict(data: dict[K, V]): + return {v: k for k, v in data.items()} + + @t.register_generic_converter(dict[V, K], B[V, K]) # type: ignore + def convert_dict_to_b(data: dict[V, K]): + return B[V, K]() + + @t.auto_convert(localns=locals()) + def test_generic(a: "B[str, int]"): + return a + + result = test_generic({1: "1", 2: "2"}) + self.assertIsInstance(result, B) + + result = test_generic({"1": 1, "2": 2}) + self.assertIsInstance(result, B) + def test_async_auto_convert(self): t = self.converter @@ -555,6 +575,38 @@ async def test_async_conversion_protocol(): asyncio.run(test_async_conversion_protocol()) + def test_async_auto_convert_generic(self): + t = self.converter + + class A: ... + + class B(Generic[K, V]): ... + + @t.register_converter(list[dict[str, int]], A) + def convert_list_dict_to_a(data: list[dict[str, int]]): + return A() + + @t.register_generic_converter(dict[K, V], dict[V, K]) # type: ignore + def convert_dict(data: dict[K, V]): + return {v: k for k, v in data.items()} + + @t.register_generic_converter(dict[V, K], B[V, K]) # type: ignore + def convert_dict_to_b(data: dict[V, K]): + return B[V, K]() + + @t.auto_convert(localns=locals()) + async def test_generic(a: "B[str, int]"): + return a + + async def test_async_auto_convert_generic(): + result = await test_generic({1: "1", 2: "2"}) + self.assertIsInstance(result, B) + + result = await test_generic({"1": 1, "2": 2}) + self.assertIsInstance(result, B) + + asyncio.run(test_async_auto_convert_generic()) + if __name__ == "__main__": unittest.main()