Skip to content

Commit

Permalink
v0.1.2rc2
Browse files Browse the repository at this point in the history
  • Loading branch information
luxuncang committed May 31, 2024
1 parent a781973 commit 3f9dfe6
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "typegraph3"
version = "0.1.2rc1"
version = "0.1.2rc2"
description = "Type Auto Switch"
authors = [
{name = "luxuncang", email = "luxuncang@qq.com"},
Expand Down
103 changes: 84 additions & 19 deletions src/typegraph/converter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@
Set,
Iterable,
Iterator,
Sequence,
MutableSequence,
cast,
Type,
Awaitable,
Any,
Optional,
Mapping,
MutableMapping,
get_type_hints,
Generic,
is_typeddict,
)
from functools import wraps, reduce

import networkx as nx
from typing_extensions import get_args, get_origin
from typing_inspect import is_union_type, get_generic_type
from typeguard import check_type, TypeCheckError, CollectionCheckStrategy
from typeguard import check_type, CollectionCheckStrategy

from .typevar import iter_deep_type, gen_typevar_model, extract_typevar_mapping
from ..type_utils import (
Expand All @@ -32,6 +36,8 @@
is_protocol_type,
check_protocol_type,
get_subclass_types,
get_connected_subgraph,
iter_type_args,
)


Expand All @@ -50,6 +56,7 @@ def __init__(self):
self.pG = nx.DiGraph()
self.tG = nx.DiGraph()
self.pmG = nx.DiGraph()
self.qG = nx.DiGraph()
TypeConverter.instances.append(self)

def get_graph(
Expand Down Expand Up @@ -96,8 +103,6 @@ def _gen_edge(
metadata={"protocol": True},
)
for p_type in self.get_protocol_types(out_type):
if out_type == str:
print(p_type, in_type, out_type, converter)
self.pG.add_edge(
out_type,
p_type,
Expand All @@ -110,6 +115,11 @@ def _gen_graph(self, in_type: Type[In], out_type: Type[Out]):
tmp_G = nx.DiGraph()
im = gen_typevar_model(in_type)
om = gen_typevar_model(out_type)

def _gen_sub_graph(mapping, node):
for su, sv, sc in get_connected_subgraph(self.tG, node).edges(data=True):
tmp_G.add_edge(su.get_instance(mapping), sv.get_instance(mapping), **sc)

for u, v, c in self.tG.edges(data=True):
um = gen_typevar_model(u)
vm = gen_typevar_model(v)
Expand All @@ -120,8 +130,27 @@ def _gen_graph(self, in_type: Type[In], out_type: Type[Out]):
tmp_G.add_edge(
um.get_instance(mapping), vm.get_instance(mapping), **c
)
_gen_sub_graph(mapping, t)
except Exception:
...

for su, sv, sc in self.G.edges(data=True):
su_m = gen_typevar_model(su)
for arg in iter_type_args(su_m):
try:
mapping = extract_typevar_mapping(um, arg)
sub_m = su_m.replace_args(
arg,
gen_typevar_model(vm.get_instance(mapping)), # type: ignore
).get_instance()
tmp_G.add_edge(
sub_m,
sv,
**sc,
)
except Exception:
...
self.qG = nx.compose(self.qG, tmp_G)
return tmp_G

def register_generic_converter(self, input_type: Type, out_type: Type):
Expand Down Expand Up @@ -333,6 +362,8 @@ def __iter_func_dict(item):
res = map(_iter_func, input_value)
elif in_origin == dict or out_origin == Dict:
res = dict(map(__iter_func_dict, input_value.items()))
elif out_origin in (Mapping, MutableMapping):
res = dict(map(__iter_func_dict, input_value.items()))
else:
raise ValueError(
f"Unsupported structural_type {input_type} to {out_type}"
Expand Down Expand Up @@ -412,6 +443,11 @@ async def __iter_func_dict(item):
*map(__iter_func_dict, input_value.items())
)
res = dict(items)
elif out_origin in (Mapping, MutableMapping):
items = await asyncio.gather(
*map(__iter_func_dict, input_value.items())
)
res = dict(items)
else:
raise ValueError(
f"Unsupported structural_type {input_type} to {out_type}"
Expand Down Expand Up @@ -497,17 +533,38 @@ def get_edges(self, sub_class: bool = False, protocol: bool = False):
):
yield edge

def show_mermaid_graph(self, sub_class: bool = False, protocol: bool = False):
def show_mermaid_graph(
self, sub_class: bool = False, protocol: bool = False, full: bool = False
):
from IPython.display import display, Markdown
import typing

nodes = []

def get_name(cls):
if type(cls) in (typing._GenericAlias, typing.GenericAlias): # type: ignore
return str(cls)
elif hasattr(cls, "__name__"):
return cls.__name__
return str(cls)

def get_node_name(cls):
return f"node{nodes.index(cls)}"

text = "```mermaid\ngraph TD;\n"
for edge in self.get_edges(sub_class=sub_class, protocol=protocol):
line_style = "--" if edge[2]["line"] else "-.-"
text += f"{edge[0].__name__}{line_style}>{edge[1].__name__}\n"
for edge in self.get_graph(
sub_class=sub_class, protocol=protocol, combos=[self.qG] if full else None
).edges(data=True):
if edge[0] not in nodes:
nodes.append(edge[0])
if edge[1] not in nodes:
nodes.append(edge[1])
line_style = "--" if edge[2].get("line", False) else "-.-"
text += f'{get_node_name(edge[0])}["{get_name(edge[0])}"] {line_style}> {get_node_name(edge[1])}["{get_name(edge[1])}"]\n'
text += "```"

display(Markdown(text))
return text
# return text

def get_all_paths(
self,
Expand Down Expand Up @@ -566,13 +623,21 @@ def get_check_types_by_value(
for edge in G.edges():
if edge[0] in nodes:
continue
nodes.add(edge[0])
try:
check_type(
input_value,
edge[0],
collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
)
except Exception:
continue
yield edge[0]
if get_origin(edge[0]) in (
Iterable,
Iterator,
Mapping,
MutableMapping,
Sequence,
MutableSequence,
) or is_typeddict(edge[0]):
nodes.add(edge[0])
try:
check_type(
input_value,
edge[0],
collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
)
except Exception:
continue
yield edge[0]
30 changes: 29 additions & 1 deletion src/typegraph/converter/typevar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import (
TypeVar,
List,
Expand All @@ -12,7 +13,7 @@
from typing_inspect import is_typevar

from ..type_utils import (
get_origin as get_real_origin,
get_real_origin,
generate_type,
)

Expand Down Expand Up @@ -54,6 +55,33 @@ def get_instance(self, instance: Optional[dict[Type[TypeVar], Type]] = None):
raise ValueError("Invalid TypeVarModel")
return generate_type(generic, args_list)

def replace_args(self, source: TypeVarModel, target: TypeVarModel) -> 'TypeVarModel':
if self.args is None:
return TypeVarModel(self.origin)

new_args = []
for arg in self.args:
if isinstance(arg, TypeVarModel):
if arg == source:
new_args.append(target)
else:
new_args.append(arg.replace_args(source, target))
elif isinstance(arg, list):
new_list = []
for item in arg:
if isinstance(item, TypeVarModel):
if item == source:
new_list.append(target)
else:
new_list.append(item.replace_args(source, target))
else:
new_list.append(item)
new_args.append(new_list)
else:
new_args.append(arg)

return TypeVarModel(self.origin, new_args)

def depth_first_traversal(self, parent=None, parent_arg_index=None, depth=1):
if self.args:
for i, arg in enumerate(self.args):
Expand Down
54 changes: 42 additions & 12 deletions src/typegraph/type_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import types
import typing
import networkx as nx

from typing import (
Union,
List,
Callable,
Any,
runtime_checkable,
Type,
)
from typing import Union, List, Callable, Any, runtime_checkable, Type, get_args
from typing_extensions import get_type_hints
from typing_inspect import get_generic_type


def get_origin(tp):
def get_real_origin(tp):
"""Get the unsubscripted version of a type.
This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar
Expand All @@ -37,7 +31,7 @@ def get_origin(tp):
if isinstance(
tp,
(
typing._BaseGenericAlias, # type: ignore
typing._BaseGenericAlias, # type: ignore
typing.GenericAlias, # type: ignore
typing.ParamSpecArgs,
typing.ParamSpecKwargs,
Expand All @@ -52,7 +46,7 @@ def get_origin(tp):


def is_structural_type(tp):
if get_origin(tp):
if get_real_origin(tp):
return True
return False

Expand Down Expand Up @@ -155,4 +149,40 @@ def get_subclass_types(cls: Type):
if hasattr(cls, "__subclasses__"):
for subclass in cls.__subclasses__():
yield subclass
yield from get_subclass_types(subclass)
yield from get_subclass_types(subclass)


def get_connected_nodes(graph, node):
if node not in graph:
return set()

# 获取正向连通的节点
successors = set(nx.descendants(graph, node))

# 获取逆向连通的节点
predecessors = set(nx.ancestors(graph, node))

# 合并所有连通的节点
connected_nodes = successors | predecessors | {node}

return connected_nodes


def get_connected_subgraph(graph, node):
connected_nodes = get_connected_nodes(graph, node)
subgraph = graph.subgraph(connected_nodes).copy()
return subgraph


def iter_type_args(tp):
args = tp.args
if args:
for arg in args:
if isinstance(arg, list):
for i in arg:
yield i
yield from iter_type_args(i)
else:
yield arg
yield from iter_type_args(arg)

Loading

0 comments on commit 3f9dfe6

Please sign in to comment.