Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): support wrapped aliases in pydantic v2 #5102

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
43 changes: 26 additions & 17 deletions .github/workflows/publish-generator-fastapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ on:
push:
branches:
- main
paths:
- "generators/python/fastapi/versions.yml"
- dsinghvi/v2-wrapped-aliases
workflow_dispatch:
inputs:
version:
Expand Down Expand Up @@ -53,28 +52,38 @@ jobs:

# Manually publish and register generators
# The logic is identical to the step above, but could not find a good way to work with the matrix AND manual trigger in one job
publish-manually:
publish-fastapi:
runs-on: ubuntu-latest
# Only run this job if this has been triggered manually
if: ${{ github.event_name == 'workflow_dispatch' }}
steps:
- name: Checkout repo at current ref
- name: Checkout repo
uses: actions/checkout@v4
with:
ref: main
fetch-depth: 2
fetch-depth: 0

- name: Install Poetry
uses: snok/install-poetry@v1

- name: Install
uses: ./.github/actions/install
working-directory: ./generators/python
run: poetry install

- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: fernapi
password: ${{ secrets.FERN_API_DOCKERHUB_PASSWORD }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1

- name: Run publish
env:
FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
DOCKER_USERNAME: fernapi
DOCKER_PASSWORD: ${{ secrets.FERN_API_DOCKERHUB_PASSWORD }}
run: |
pnpm seed:local publish generator fastapi --ver ${{ inputs.version }} --log-level debug
pnpm seed:local register generator --generators fastapi
- name: Build and push Docker image
uses: docker/build-push-action@v2
with:
context: ./generators/python
file: ./generators/python/fastapi/Dockerfile
platforms: linux/amd64,linux/arm64
cache-from: type=gha
cache-to: type=gha,mode=min
push: true
labels: version=1.6.0-rc5
tags: fernapi/fern-fastapi-server:1.6.0-rc5
14 changes: 14 additions & 0 deletions generators/python/fastapi/versions.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# For unreleased changes, use unreleased.yml
- version: 1.6.0-rc1
irVersion: 53
changelogEntry:
- type: fix
summary: |
Support wrapped aliases in Pydantic V2.

- version: 1.6.0-rc0
irVersion: 53
changelogEntry:
- type: fix
summary: |
Support wrapped aliases in Pydantic V2.

- version: 1.5.0
irVersion: 53
changelogEntry:
Expand Down
13 changes: 8 additions & 5 deletions generators/python/src/fern_python/cli/abstract_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,14 @@ def generate_project(
if output_mode_union.type == "downloadFiles":
# since download files does not contain a pyproject.toml
# we run ruff using the fern_python poetry.toml (copied into the docker)
publisher._run_command(
command=["poetry", "run", "ruff", "format", "/fern/output"],
safe_command="poetry run ruff format /fern/output",
cwd="/",
)
try:
publisher._run_command(
command=["poetry", "run", "ruff", "format", "/fern/output"],
safe_command="poetry run ruff format /fern/output",
cwd="/",
)
except:
pass
elif output_mode_union.type == "github":
publisher.run_poetry_install()
publisher.run_ruff_format()
Expand Down
13 changes: 8 additions & 5 deletions generators/python/src/fern_python/cli/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ def __init__(
self._generator_config = generator_config

def run_ruff_format(self) -> None:
self._run_command(
command=["poetry", "run", "ruff", "format", "--cache-dir", "../.ruffcache"],
safe_command="poetry run ruff format",
cwd=None,
)
try:
self._run_command(
command=["poetry", "run", "ruff", "format", "--cache-dir", "../.ruffcache"],
safe_command="poetry run ruff format",
cwd=None,
)
except:
pass

def run_poetry_install(self) -> None:
self._run_command(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

from .module import Module
from .qualfied_name import QualifiedName
Expand Down Expand Up @@ -70,3 +70,5 @@ class Reference:
This was originally added to ensure the use of NotRequired for TypedDicts
brings in this import.
"""

generic: Optional[Any] = None
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def write_node(self, node: AST.AstNode, should_write_as_snippet: Optional[bool]
node.write(writer=self, should_write_as_snippet=should_write_as_snippet)

def write_reference(self, reference: AST.Reference) -> None:
self.write(self._reference_resolver.resolve_reference(reference))
self.write(self._reference_resolver.resolve_reference(reference, self))

def should_format_as_snippet(self) -> bool:
return self._should_format_as_snippet
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

class ReferenceResolver(ABC):
@abstractmethod
def resolve_reference(self, reference: AST.Reference) -> str:
def resolve_reference(self, reference: AST.Reference, writer: AST.NodeWriter) -> str:
...
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from ordered_set import OrderedSet

from fern_python.codegen.node_writer_impl import NodeWriterImpl

from . import AST
from .reference_resolver import ReferenceResolver

Expand Down Expand Up @@ -76,7 +78,7 @@ def resolve_references(self) -> None:
prefix_for_qualfied_names=prefix_for_qualfied_names,
)

def resolve_reference(self, reference: AST.Reference) -> str:
def resolve_reference(self, reference: AST.Reference, writer: AST.NodeWriter) -> str:
if self._original_import_to_resolved_import is None:
raise RuntimeError("References have not yet been resolved.")

Expand All @@ -97,6 +99,20 @@ def resolve_reference(self, reference: AST.Reference) -> str:
)
)

if reference.generic is not None:
temp_writer = NodeWriterImpl(
should_format=False,
should_format_as_snippet=False,
whitelabel=False,
should_include_header=False,
reference_resolver=self,
)

resolved_reference += "["
reference.generic.write(writer=temp_writer)
resolved_reference += temp_writer.to_str()
resolved_reference += "]"

# Here we string-reference a type reference if the import is marked for `if TYPE_CHECKING` or if the import
# is deferred until after the current declaration (e.g. for circular references when defining Pydantic models).
return (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
from typing import Optional, Tuple

from fern_python.codegen import AST

Expand All @@ -14,44 +15,47 @@ class PydanticVersionCompatibility(str, enum.Enum):
Both = "both"


def _export(*export: str) -> AST.ClassReference:
def _export(export: Tuple[str, ...], generic: Optional[AST.TypeHint] = None) -> AST.ClassReference:
return AST.ClassReference(
import_=AST.ReferenceImport(
module=AST.Module.external(
dependency=PYDANTIC_DEPENDENCY,
module_path=("pydantic",),
),
),
qualified_name_excluding_import=export,
import_=Pydantic.PydanticImport(), qualified_name_excluding_import=export, generic=generic
)


class Pydantic:
@staticmethod
def Field() -> AST.ClassReference:
return _export("Field")
return _export(("Field",))

@staticmethod
def BaseModel() -> AST.ClassReference:
return _export("BaseModel")
return _export(("BaseModel",))

@staticmethod
def ConfigDict() -> AST.ClassReference:
return _export("ConfigDict")
return _export(("ConfigDict",))

@staticmethod
def PrivateAttr() -> AST.ClassReference:
return _export("PrivateAttr")
return _export(("PrivateAttr",))

@staticmethod
def RootModel(generic: Optional[AST.TypeHint] = None) -> AST.ClassReference:
return _export(("RootModel",), generic)

@staticmethod
def RootModel() -> AST.ClassReference:
return _export("RootModel")
def PydanticImport() -> AST.ReferenceImport:
return AST.ReferenceImport(
module=AST.Module.external(
dependency=PYDANTIC_DEPENDENCY,
module_path=("pydantic",),
),
)

class Extra:
@staticmethod
def forbid() -> AST.Expression:
return AST.Expression(_export("Extra", "forbid"))
return AST.Expression(_export(("Extra", "forbid")))

@staticmethod
def allow() -> AST.Expression:
return AST.Expression(_export("Extra", "allow"))
return AST.Expression(_export(("Extra", "allow")))
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def check_wrapped_aliases_v1_only(self) -> Self:
version_compat = self.version
use_wrapped_aliases = self.wrapped_aliases

if use_wrapped_aliases and version_compat != PydanticVersionCompatibility.V1:
if use_wrapped_aliases and version_compat == PydanticVersionCompatibility.Both:
raise ValueError(
"Wrapped aliases are only supported in Pydantic V1, please update your `version` field to be 'v1' to continue using wrapped aliases."
"Wrapped aliases are not supported for `both`, please update your `version` field to be 'v1' or `v2` to continue using wrapped aliases."
)

if self.enum_type != "literals":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import fern.ir.resources as ir_types

from fern_python.codegen import AST, LocalClassReference, SourceFile
from fern_python.external_dependencies.pydantic import PydanticVersionCompatibility
from fern_python.external_dependencies.pydantic import Pydantic
from fern_python.pydantic_codegen import PydanticField, PydanticModel

from ..context import PydanticGeneratorContext
from .custom_config import PydanticModelCustomConfig
from .validators import (
PydanticV1CustomRootTypeValidatorsGenerator,
PydanticCustomRootTypeValidatorsGenerator,
PydanticValidatorsGenerator,
ValidatorsGenerator,
)
Expand Down Expand Up @@ -70,6 +70,8 @@ def __init__(

self._model_contains_forward_refs = False

self._is_pydantic_v2_only = self._custom_config.version == "v2"

models_to_extend = [item for item in base_models] if base_models is not None else []
extends_crs = (
[context.get_class_reference_for_type_id(extended.type_id, as_request=False) for extended in extends]
Expand All @@ -95,7 +97,9 @@ def __init__(
require_optional_fields=custom_config.require_optional_fields,
is_pydantic_v2=self._context.core_utilities.get_is_pydantic_v2(),
universal_field_validator=self._context.core_utilities.universal_field_validator,
universal_root_validator=self._context.core_utilities.universal_root_validator,
universal_root_validator=self._model_validator
if self._is_pydantic_v2_only
else self._context.core_utilities.universal_root_validator,
is_root_model=is_root_model,
update_forward_ref_function_reference=self._context.core_utilities.get_update_forward_refs(),
field_metadata_getter=lambda: self._context.core_utilities.get_field_metadata(),
Expand All @@ -111,6 +115,15 @@ def to_reference(self) -> LocalClassReference:
def get_class_name(self) -> str:
return self._class_name

def _model_validator(self, pre: bool = False) -> AST.FunctionInvocation:
return AST.FunctionInvocation(
function_definition=AST.Reference(
qualified_name_excluding_import=("model_validator",),
import_=Pydantic.PydanticImport(),
),
kwargs=[("mode", AST.Expression(expression='"before"' if pre else '"after"'))],
)

def add_field(
self,
*,
Expand Down Expand Up @@ -262,22 +275,22 @@ def add_method_unsafe(
) -> AST.FunctionDeclaration:
return self._pydantic_model.add_method(declaration=declaration, decorator=decorator)

def set_root_type_v1_only(
def set_root_type(
self,
root_type: ir_types.TypeReference,
annotation: Optional[AST.Expression] = None,
is_forward_ref: bool = False,
) -> None:
self.set_root_type_unsafe_v1_only(
self.set_root_type_unsafe(
root_type=self.get_type_hint_for_type_reference(root_type),
annotation=annotation,
is_forward_ref=is_forward_ref,
)

def set_root_type_unsafe_v1_only(
def set_root_type_unsafe(
self, root_type: AST.TypeHint, annotation: Optional[AST.Expression] = None, is_forward_ref: bool = False
) -> None:
self._pydantic_model.set_root_type_unsafe_v1_only(root_type=root_type, annotation=annotation)
self._pydantic_model.set_root_type_unsafe(root_type=root_type, annotation=annotation)

def add_ghost_reference(self, type_id: ir_types.TypeId) -> None:
self._pydantic_model.add_ghost_reference(
Expand All @@ -286,10 +299,7 @@ def add_ghost_reference(self, type_id: ir_types.TypeId) -> None:

def finish(self) -> None:
if self._custom_config.include_validators:
if (
self._pydantic_model._v1_root_type is None
and self._custom_config.version == PydanticVersionCompatibility.V1
):
if self._pydantic_model._root_type is None:
self._pydantic_model.add_partial_class()
self._get_validators_generator().add_validators()
if self._model_contains_forward_refs or self._force_update_forward_refs:
Expand All @@ -309,11 +319,11 @@ def finish(self) -> None:
self._pydantic_model.finish()

def _get_validators_generator(self) -> ValidatorsGenerator:
v1_root_type = self._pydantic_model.get_root_type_unsafe_v1_only()
if v1_root_type is not None and self._custom_config.version == PydanticVersionCompatibility.V1:
return PydanticV1CustomRootTypeValidatorsGenerator(
root_type = self._pydantic_model.get_root_type_unsafe()
if root_type is not None:
return PydanticCustomRootTypeValidatorsGenerator(
model=self._pydantic_model,
root_type=v1_root_type,
root_type=root_type,
)
else:
unique_name = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def get_dict_method(writer: AST.NodeWriter) -> None:
)

if self._custom_config.version == PydanticVersionCompatibility.V1:
external_pydantic_model.set_root_type_unsafe_v1_only(
external_pydantic_model.set_root_type_unsafe(
is_forward_ref=True,
root_type=root_type,
annotation=AST.Expression(
Expand Down
Loading
Loading