Skip to content

Commit

Permalink
Refactor and improve support for model creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Viicos committed Feb 4, 2024
1 parent a74f62c commit 0581548
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 122 deletions.
10 changes: 10 additions & 0 deletions docs/api/dynamic_stubs/rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
members:
- STUB_FILES

::: django_autotyping.stubbing.codemods.model_init_overload_codemod.ModelInitOverloadCodemod
options:
show_bases: false
show_source: false
show_root_heading: true
show_root_full_path: false
show_if_no_docstring: true
members:
- STUB_FILES

::: django_autotyping.stubbing.codemods.get_model_overload_codemod.GetModelOverloadCodemod
options:
show_bases: false
Expand Down
3 changes: 2 additions & 1 deletion docs/usage/dynamic_stubs.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ python manage.py generate_stubs --local-stubs-dir typings/ --ignore DJAS001
The following is a list of the available rules related to dynamic stubs:

- [`DJAS001`][django_autotyping.stubbing.codemods.forward_relation_overload_codemod.ForwardRelationOverloadCodemod]: add overloads to the `__init__` methods of related fields.
- [`DJAS002`][django_autotyping.stubbing.codemods.create_overload_codemod.CreateOverloadCodemod]: Add overloads to methods creating an instance of a model.
- [`DJAS002`][django_autotyping.stubbing.codemods.create_overload_codemod.CreateOverloadCodemod]: Add overloads to the [`create`][django.db.models.query.QuerySet.create] method.
- [`DJAS003`][django_autotyping.stubbing.codemods.model_init_overload_codemod.ModelInitOverloadCodemod]: Add overloads to the [`Model.__init__`][django.db.models.Model] method.
- [`DJAS010`][django_autotyping.stubbing.codemods.get_model_overload_codemod.GetModelOverloadCodemod]: Add overloads to the [`apps.get_model`][django.apps.apps.get_model] method.
- [`DJAS011`][django_autotyping.stubbing.codemods.auth_functions_codemod.AuthFunctionsCodemod]: Add a custom return type to the to auth related functions.
- [`DJAS015`][django_autotyping.stubbing.codemods.reverse_overload_codemod.ReverseOverloadCodemod]: Add overloads to the [`reverse`][django.urls.reverse] function.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
dependencies = [
"django",
"libcst>=0.4.10",
"typing-extensions>=4.2.0; python_version < '3.11'",
"typing-extensions>=4.4.0; python_version < '3.12'",
]
license = {file = "LICENSE"}
dynamic = ["version", "readme"]
Expand Down
6 changes: 3 additions & 3 deletions src/django_autotyping/_compat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import sys
from pathlib import Path

if sys.version_info >= (3, 11):
from typing import NotRequired, Required, Self, TypeAlias, TypeGuard, Unpack
if sys.version_info >= (3, 12):
from typing import NotRequired, Required, Self, TypeAlias, TypeGuard, Unpack, override
else:
from typing_extensions import NotRequired, Required, Self, TypeAlias, TypeGuard, Unpack # noqa: F401
from typing_extensions import NotRequired, Required, Self, TypeAlias, TypeGuard, Unpack, override # noqa: F401

if sys.version_info >= (3, 10):
from types import NoneType
Expand Down
5 changes: 3 additions & 2 deletions src/django_autotyping/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,14 @@ class StubsGenerationSettings:
This is why, by default, this configuration attribute defaults to `True`. If set to `False`,
`django-autotyping` will try its best to determine required fields, namely by checking if:
- the field can be [`null`][django.db.models.Field.null]
- the field can be [`null`][django.db.models.Field.null] or [`blank`][django.db.models.Field.null]
- the field is a primary key
- the field has a default or a database default value set
- the field is a subclass of [`DateField`][django.db.models.DateField] and has
[`auto_now`][django.db.models.DateField.auto_now] or [`auto_now_add`][django.db.models.DateField.auto_now_add]
set to `True`.
Affected rules: `DJAS002`.
Affected rules: `DJAS002`, `DJAS003`.
"""

ALLOW_REVERSE_ARGS: bool = False
Expand Down
6 changes: 4 additions & 2 deletions src/django_autotyping/stubbing/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .create_overload_codemod import CreateOverloadCodemod
from .forward_relation_overload_codemod import ForwardRelationOverloadCodemod
from .get_model_overload_codemod import GetModelOverloadCodemod
from .model_init_overload_codemod import ModelInitOverloadCodemod
from .query_lookups_overload_codemod import QueryLookupsOverloadCodemod
from .reverse_overload_codemod import ReverseOverloadCodemod
from .settings_codemod import SettingCodemod
Expand All @@ -29,12 +30,13 @@
"gather_codemods",
)

RulesT: TypeAlias = Literal["DJAS001", "DJAS002", "DJAS010", "DJAS011", "DJAS015", "DJAS016"]
RulesT: TypeAlias = Literal["DJAS001", "DJAS002", "DJAS003", "DJAS010", "DJAS011", "DJAS015", "DJAS016"]

rules: list[tuple[RulesT, type[StubVisitorBasedCodemod]]] = [
("DJAS001", ForwardRelationOverloadCodemod),
("DJAS002", CreateOverloadCodemod),
# ("DJAS003", QueryLookupsOverloadCodemod),
("DJAS003", ModelInitOverloadCodemod),
# ("DJAS004", QueryLookupsOverloadCodemod),
("DJAS010", GetModelOverloadCodemod),
("DJAS011", AuthFunctionsCodemod),
("DJAS015", ReverseOverloadCodemod),
Expand Down
216 changes: 216 additions & 0 deletions src/django_autotyping/stubbing/codemods/_model_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from abc import ABC, abstractmethod
from typing import ClassVar, TypedDict, cast

import libcst as cst
from django.db.models import (
AutoField,
BooleanField,
CharField,
DateField,
DateTimeField,
DecimalField,
Field,
FloatField,
GenericIPAddressField,
IntegerField,
IPAddressField,
TextField,
TimeField,
UUIDField,
)
from django.db.models.fields.reverse_related import ForeignObjectRel
from libcst import helpers
from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, ImportItem
from libcst.metadata import ScopeProvider

from django_autotyping._compat import Required
from django_autotyping.typing import FlattenFunctionDef

from ._utils import TypedDictAttribute, build_typed_dict, get_param
from .base import InsertAfterImportsVisitor, StubVisitorBasedCodemod
from .constants import OVERLOAD_DECORATOR


class FieldType(TypedDict):
type: Required[str]
"""The stringified type annotation to be used."""

typing_imports: list[str]
"""A list of typing objects to be imported."""

extra_imports: list[ImportItem]
"""A list of extra import items to be included."""


# This types are taken from `django-stubs`
# NOTE: Order matters! This dict is iterated in order to match field classes
# against the keys. Be sure to define the most specific subclasses first
# (e.g. `AutoField` is a subclass of `IntegerField`, so it is defined first).
# NOTE: Maybe `get_args(field_instance.__orig_class__)` could be used to take into
# account explicit parametrization.
FIELD_SET_TYPES_MAP: dict[type[Field], FieldType] = {
AutoField: {
"type": "Combinable | int | str",
},
IntegerField: {"type": "float | int | str | Combinable"},
FloatField: {"type": "float | int | str | Combinable"},
DecimalField: {"type": "str | float | Decimal | Combinable", "extra_imports": [ImportItem("decimal", "Decimal")]},
CharField: {"type": "str | int | Combinable"}, # TODO this and textfield seems to allow `SupportsStr`
TextField: {"type": "str | Combinable"},
BooleanField: {"type": "bool | Combinable"},
IPAddressField: {"type": "str | Combinable"},
GenericIPAddressField: {
"type": "str | int | Callable[..., Any] | Combinable", # TODO, Callable, really?
"typing_imports": ["Any", "Callable"],
},
DateTimeField: {
"type": "str | datetime | date | Combinable",
"extra_imports": [ImportItem("datetime", "date"), ImportItem("datetime", "datetime")],
},
DateField: {
"type": "str | date | Combinable",
"extra_imports": [ImportItem("datetime", "date")],
},
TimeField: {
"type": "str | time | datetime | Combinable",
"extra_imports": [ImportItem("datetime", "time"), ImportItem("datetime", "datetime")],
},
UUIDField: {"type": "str | UUID", "extra_imports": [ImportItem("uuid", "UUID")]},
Field: {"type": "Any", "typing_imports": ["Any"]},
}
"""A mapping of field classes to the types they allow to be set to."""


class ModelCreationBaseCodemod(StubVisitorBasedCodemod, ABC):
"""A base codemod that can be used to add overloads for model creation.
Useful for: `Model.__init__`, `BaseManager.create`.
"""

METADATA_DEPENDENCIES = {ScopeProvider}
KWARGS_TYPED_DICT_NAME: ClassVar[str]
"""A templated string to render the name of the `TypedDict` for the `**kwargs` annotation.
Should contain the template `{model_name}`.
"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
self.add_model_imports()

model_typed_dicts = self.build_model_kwargs()
InsertAfterImportsVisitor.insert_after_imports(context, model_typed_dicts)

AddImportsVisitor.add_needed_import(
self.context,
module="django.db.models.expressions",
obj="Combinable",
)

# Even though these are most likely included, we import them for safety:
self.add_typing_imports(["TypedDict", "TypeVar", "Unpack", "overload"])

def build_model_kwargs(self) -> list[cst.ClassDef]:
"""Return a list of class definition representing the typed dicts to be used for overloads."""

contenttypes_installed = self.django_context.apps.is_installed("django.contrib.contenttypes")
if contenttypes_installed:
from django.contrib.contenttypes.fields import GenericForeignKey
all_optional = self.stubs_settings.MODEL_FIELDS_OPTIONAL

class_defs: list[cst.ClassDef] = []

for model in self.django_context.models:
model_name = self.django_context.get_model_name(model)

# This mostly follows the implementation of the Django's `Model.__init__` method:
typed_dict_attributes = []
for field in cast(list[Field], model._meta.fields):
if isinstance(field.remote_field, ForeignObjectRel):
# TODO support for attname as well (i.e. my_foreign_field_id).
# Issue is if this is a required field, we can't make both required at the same time
attr_name = field.name
annotation = self.django_context.get_model_name(
# As per `ForwardManyToOneDescriptor.__set__`:
field.remote_field.model._meta.concrete_model
)
annotation += " | Combinable"
elif contenttypes_installed and isinstance(field, GenericForeignKey):
# it's generic, so cannot set specific model
attr_name = field.name
annotation = "Any"
self.add_typing_imports(["Any"])
else:
attr_name = field.attname
# Regular fields:
field_set_type = next(
(v for k, v in FIELD_SET_TYPES_MAP.items() if issubclass(type(field), k)),
FieldType(type="Any", typing_imports=["Any"]),
)

self.add_typing_imports(field_set_type.get("typing_imports", []))
if extra_imports := field_set_type.get("extra_imports"):
imports = AddImportsVisitor._get_imports_from_context(self.context)
imports.extend(extra_imports)
self.context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

annotation = field_set_type["type"]

if not isinstance(field, GenericForeignKey) and self.django_context.is_nullable_field(field):
annotation += " | None"

typed_dict_attributes.append(
TypedDictAttribute(
attr_name,
annotation=annotation,
docstring=getattr(field, "help_text", None) or None,
required=not all_optional and not self.django_context.is_required_field(field),
)
)

class_defs.append(
build_typed_dict(
self.KWARGS_TYPED_DICT_NAME.format(model_name=model_name),
attributes=typed_dict_attributes,
total=False,
leading_line=True,
)
)

return class_defs

def mutate_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> FlattenFunctionDef:
class_name = self.get_metadata(ScopeProvider, original_node).name

overload = updated_node.with_changes(decorators=[OVERLOAD_DECORATOR])
overloads: list[cst.FunctionDef] = []

for model in self.django_context.models:
model_name = self.django_context.get_model_name(model)

# sets `self: BaseManager[model_name]/_QuerySet[model_name, _Row]/model_name`
annotation = self.get_self_annotation(model_name, class_name)
self_param = get_param(overload, "self")
overload_ = overload.with_deep_changes(
old_node=self_param,
annotation=cst.Annotation(annotation),
)

overload_ = overload_.with_deep_changes(
old_node=overload_.params.star_kwarg,
annotation=cst.Annotation(
annotation=helpers.parse_template_expression(
f"Unpack[{self.KWARGS_TYPED_DICT_NAME}]".format(model_name=model_name)
)
),
)

overloads.append(overload_)

return cst.FlattenSentinel(overloads)

@abstractmethod
def get_self_annotation(self, model_name: str, class_name: str) -> cst.BaseExpression:
"""Return the annotation to be set on the `self` parameter."""
pass
4 changes: 3 additions & 1 deletion src/django_autotyping/stubbing/codemods/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING, ClassVar, Sequence, TypeVar, cast

import libcst as cst
Expand Down Expand Up @@ -56,10 +57,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
)


class StubVisitorBasedCodemod(VisitorBasedCodemodCommand):
class StubVisitorBasedCodemod(VisitorBasedCodemodCommand, ABC):
"""The base class for all codemods used for custom stub files."""

STUB_FILES: ClassVar[set[str]]
"""A set of stub files the codemod should apply to."""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
Expand Down
Loading

0 comments on commit 0581548

Please sign in to comment.