Skip to content

Commit

Permalink
Issue #35 & #36 added variance parsing api_analyzer; added variance a…
Browse files Browse the repository at this point in the history
…nd constraints to stubs generator
  • Loading branch information
Masara committed Nov 20, 2023
1 parent 9eb19f5 commit 69d421f
Show file tree
Hide file tree
Showing 14 changed files with 622 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/safeds_stubgen/api_analyzer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ParameterAssignment,
QualifiedImport,
Result,
VarianceType,
WildcardImport,
)
from ._get_api import get_api
Expand Down Expand Up @@ -64,5 +65,6 @@
"SetType",
"TupleType",
"UnionType",
"VarianceType",
"WildcardImport",
]
22 changes: 22 additions & 0 deletions src/safeds_stubgen/api_analyzer/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class Class:
attributes: list[Attribute] = field(default_factory=list)
methods: list[Function] = field(default_factory=list)
classes: list[Class] = field(default_factory=list)
variances: list[Variance] = field(default_factory=list)

def to_dict(self) -> dict[str, Any]:
return {
Expand All @@ -187,6 +188,7 @@ def to_dict(self) -> dict[str, Any]:
"attributes": [attribute.id for attribute in self.attributes],
"methods": [method.id for method in self.methods],
"classes": [class_.id for class_ in self.classes],
"variances": [variance.to_dict() for variance in self.variances],
}

def add_method(self, method: Function) -> None:
Expand Down Expand Up @@ -303,6 +305,26 @@ class ParameterAssignment(PythonEnum):
NAMED_VARARG = "NAMED_VARARG"


@dataclass(frozen=True)
class Variance:
name: str
type: AbstractType
variance_type: VarianceType

def to_dict(self):
return {
"name": self.name,
"type": self.type.to_dict(),
"variance_type": self.variance_type.name
}


class VarianceType(PythonEnum):
CONTRAVARIANT = "CONTRAVARIANT"
COVARIANT = "COVARIANT"
INVARIANT = "INVARIANT"


@dataclass(frozen=True)
class Result:
id: str
Expand Down
39 changes: 39 additions & 0 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
Parameter,
QualifiedImport,
Result,
Variance,
VarianceType,
WildcardImport,
)
from ._mypy_helpers import (
Expand All @@ -28,6 +30,7 @@
get_funcdef_definitions,
get_mypyfile_definitions,
mypy_type_to_abstract_type,
mypy_variance_parser,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -120,6 +123,41 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
# Get docstring
docstring = self.docstring_parser.get_class_documentation(node)

# Variance
# Special base classes like Generic[...] get moved to "removed_base_type_expr" during semantic analysis of mypy
generic_exprs = [
removed_base_type_expr
for removed_base_type_expr in node.removed_base_type_exprs
if removed_base_type_expr.base.name == "Generic"
]
variances = []
if generic_exprs:
# Can only be one, since a class can inherit "Generic" only one time
generic_expr = generic_exprs[0].index

if isinstance(generic_expr, mp_nodes.TupleExpr):
generic_types = [item.node for item in generic_expr.items]
elif isinstance(generic_expr, mp_nodes.NameExpr):
generic_types = [generic_expr.node]
else: # pragma: no cover
raise TypeError("Unexpected type while parsing generic type.")

for generic_type in generic_types:
variance_type = mypy_variance_parser(generic_type.variance)
if variance_type == VarianceType.INVARIANT:
variance_values = sds_types.UnionType([
mypy_type_to_abstract_type(value)
for value in generic_type.values
])
else:
variance_values = mypy_type_to_abstract_type(generic_type.upper_bound)

variances.append(Variance(
name=generic_type.name,
type=variance_values,
variance_type=variance_type
))

# superclasses
# Todo Aliasing: Werden noch nicht aufgelöst
superclasses = [superclass.fullname for superclass in node.base_type_exprs if hasattr(superclass, "fullname")]
Expand All @@ -144,6 +182,7 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
docstring=docstring,
reexported_by=reexported_by,
constructor_fulldocstring=constructor_fulldocstring,
variances=variances
)
self.__declaration_stack.append(class_)

Expand Down
18 changes: 16 additions & 2 deletions src/safeds_stubgen/api_analyzer/_mypy_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import mypy.types as mp_types
from mypy import nodes as mp_nodes
Expand All @@ -9,7 +9,7 @@

import safeds_stubgen.api_analyzer._types as sds_types

from ._api import ParameterAssignment
from ._api import ParameterAssignment, VarianceType

if TYPE_CHECKING:
from mypy.nodes import ClassDef, FuncDef, MypyFile
Expand Down Expand Up @@ -57,6 +57,8 @@ def mypy_type_to_abstract_type(mypy_type: Instance | ProperType | MypyType) -> A
return sds_types.NamedType(name="Any")
elif isinstance(mypy_type, mp_types.NoneType):
return sds_types.NamedType(name="None")
elif isinstance(mypy_type, mp_types.LiteralType):
return sds_types.LiteralType(literal=mypy_type.value)
elif isinstance(mypy_type, mp_types.UnboundType):
if mypy_type.name == "list":
return sds_types.ListType(types=[
Expand Down Expand Up @@ -141,3 +143,15 @@ def find_return_stmts_recursive(stmts: list[mp_nodes.Statement]) -> list[mp_node
return_stmts.append(stmt)

return return_stmts


def mypy_variance_parser(mypy_variance_type: Literal[0, 1, 2]) -> VarianceType:
match mypy_variance_type:
case 0:
return VarianceType.INVARIANT
case 1:
return VarianceType.COVARIANT
case 2:
return VarianceType.CONTRAVARIANT
case _: # pragma: no cover
raise ValueError("Mypy variance parser received an illegal parameter value.")
9 changes: 4 additions & 5 deletions src/safeds_stubgen/api_analyzer/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,17 @@ def __hash__(self) -> int:

@dataclass(frozen=True)
class LiteralType(AbstractType):
literals: list[str | int | float | bool]
literal: str | int | float | bool

@classmethod
def from_dict(cls, d: dict[str, Any]) -> LiteralType:
literals = list(d["literals"])
return LiteralType(literals)
return LiteralType(d["literal"])

def to_dict(self) -> dict[str, Any]:
return {"kind": self.__class__.__name__, "literals": self.literals}
return {"kind": self.__class__.__name__, "literal": self.literal}

def __hash__(self) -> int:
return hash(frozenset(self.literals))
return hash(frozenset([self.literal]))


@dataclass(frozen=True)
Expand Down
48 changes: 45 additions & 3 deletions src/safeds_stubgen/stubs_generator/_generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import string
from pathlib import Path
from typing import Generator
from typing import TYPE_CHECKING

from safeds_stubgen.api_analyzer import (
API,
Expand All @@ -14,9 +14,13 @@
ParameterAssignment,
QualifiedImport,
Result,
VarianceType,
WildcardImport,
)

if TYPE_CHECKING:
from collections.abc import Generator


def generate_stubs(api: API, out_path: Path) -> None:
"""Generate Safe-DS stubs.
Expand Down Expand Up @@ -151,10 +155,48 @@ def _create_class_string(self, class_: Class, class_indentation: str = "") -> st
if len(superclasses) > 1:
self._current_todo_msgs.add("multiple_inheritance")

# Variance & Constrains
constraints_info = ""
variance_info = ""
if class_.variances:
constraints = []
variances = []
for variance in class_.variances:
match variance.variance_type.name:
case VarianceType.INVARIANT.name:
variance_inheritance = ""
variance_direction = ""
case VarianceType.COVARIANT.name:
variance_inheritance = "sub"
variance_direction = "out "
case VarianceType.CONTRAVARIANT.name:
variance_inheritance = "super"
variance_direction = "in "
case _: # pragma: no cover
raise ValueError(f"Expected variance kind, got {variance.variance_type.name}.")

# Convert name to camelCase and check for keywords
variance_name_camel_case = _convert_snake_to_camel_case(variance.name)
variance_name_camel_case = self._replace_if_safeds_keyword(variance_name_camel_case)

variances.append(f"{variance_direction}{variance_name_camel_case}")
if variance_inheritance:
constraints.append(
f"{variance_name_camel_case} {variance_inheritance} "
f"{self._create_type_string(variance.type.to_dict())}"
)

if variances:
variance_info = f"<{', '.join(variances)}>"

if constraints:
constraints_info_inner = f",\n{inner_indentations}".join(constraints)
constraints_info = f" where {{\n{inner_indentations}{constraints_info_inner}\n}}"

# Class signature line
class_signature = (
f"{class_indentation}{self._create_todo_msg(class_indentation)}class "
f"{class_.name}({parameter_info}){superclass_info}"
f"{class_.name}{variance_info}({parameter_info}){superclass_info}{constraints_info}"
)

# Attributes
Expand Down Expand Up @@ -524,7 +566,7 @@ def _create_type_string(self, type_data: dict | None) -> str:
return f"Map<{key_data}>"
return "Map"
elif kind == "LiteralType":
return f"literal<{', '.join(type_data['literals'])}>"
return f"literal<{type_data['literal']}>"

raise ValueError(f"Unexpected type: {kind}")

Expand Down
13 changes: 13 additions & 0 deletions tests/data/test_package/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,16 @@ def infer_function(infer_param=1, infer_param_2: int = "Something"):
return SomeClass

return int


_T_co = TypeVar("_T_co", covariant=True, bound=str)
_T_con = TypeVar("_T_con", contravariant=True, bound=SomeClass)
_T_in = TypeVar("_T_in", int, Literal[1, 2])


class VarianceClassAll(Generic[_T_co, _T_con, _T_in]):
...


class VarianceClassOnlyInvariance(Generic[_T_in]):
...
18 changes: 18 additions & 0 deletions tests/data/test_stub_generation/variance_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Generic, TypeVar, Literal


class A:
...


_T_co = TypeVar("_T_co", covariant=True, bound=str)
_T_con = TypeVar("_T_con", contravariant=True, bound=A)
_T_in = TypeVar("_T_in", int, Literal[1, 2])


class VarianceClassAll(Generic[_T_co, _T_con, _T_in]):
...


class VarianceClassOnlyInvariance(Generic[_T_in]):
...
Loading

0 comments on commit 69d421f

Please sign in to comment.