diff --git a/README.md b/README.md index 8b125a3e0..2a1609a4a 100644 --- a/README.md +++ b/README.md @@ -438,7 +438,10 @@ Template customization: Wrap string literal by using black `experimental- string-processing` option (require black 20.8b0 or later) - --additional-imports Custom imports for output (delimited list input) + --additional-imports Custom imports for output (delimited list input). + For example "datetime.date,datetime.datetime" + --custom-formatters List of modules with custom formatter (delimited list input). + --custom-formatters-kwargs A file with kwargs for custom formatters. OpenAPI-only options: --openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...] diff --git a/datamodel_code_generator/__init__.py b/datamodel_code_generator/__init__.py index be0857fed..9df843f77 100644 --- a/datamodel_code_generator/__init__.py +++ b/datamodel_code_generator/__init__.py @@ -298,6 +298,8 @@ def generate( keep_model_order: bool = False, custom_file_header: Optional[str] = None, custom_file_header_path: Optional[Path] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, ) -> None: remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict() if isinstance(input_, str): @@ -452,6 +454,8 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]: capitalise_enum_members=capitalise_enum_members, keep_model_order=keep_model_order, known_third_party=data_model_types.known_third_party, + custom_formatters=custom_formatters, + custom_formatters_kwargs=custom_formatters_kwargs, **kwargs, ) diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index 690051493..07c5472ef 100644 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -120,7 +120,9 @@ class Config: def get_fields(cls) -> Dict[str, Any]: return cls.__fields__ - @field_validator('aliases', 'extra_template_data', mode='before') + 
@field_validator( + 'aliases', 'extra_template_data', 'custom_formatters_kwargs', mode='before' + ) def validate_file(cls, value: Any) -> Optional[TextIOBase]: if value is None or isinstance(value, TextIOBase): return value @@ -204,6 +206,14 @@ def validate_additional_imports(cls, values: Dict[str, Any]) -> Dict[str, Any]: values['additional_imports'] = [] return values + @model_validator(mode='before') + def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get('custom_formatters') is not None: + values['custom_formatters'] = values.get('custom_formatters').split(',') + else: + values['custom_formatters'] = [] + return values + if PYDANTIC_V2: @model_validator(mode='after') # type: ignore @@ -282,6 +292,8 @@ def validate_root(cls, values: Any) -> Any: keep_model_order: bool = False custom_file_header: Optional[str] = None custom_file_header_path: Optional[Path] = None + custom_formatters: Optional[List[str]] = None + custom_formatters_kwargs: Optional[TextIOBase] = None def merge_args(self, args: Namespace) -> None: set_args = { @@ -391,6 +403,28 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: ) return Exit.ERROR + if config.custom_formatters_kwargs is None: + custom_formatters_kwargs = None + else: + with config.custom_formatters_kwargs as data: + try: + custom_formatters_kwargs = json.load(data) + except json.JSONDecodeError as e: + print( + f'Unable to load custom_formatters_kwargs mapping: {e}', + file=sys.stderr, + ) + return Exit.ERROR + if not isinstance(custom_formatters_kwargs, dict) or not all( + isinstance(k, str) and isinstance(v, str) + for k, v in custom_formatters_kwargs.items() + ): + print( + 'Custom formatters kwargs mapping must be a JSON string mapping (e.g. 
{"from": "to", ...})', + file=sys.stderr, + ) + return Exit.ERROR + try: generate( input_=config.url or config.input or sys.stdin.read(), @@ -452,6 +486,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: keep_model_order=config.keep_model_order, custom_file_header=config.custom_file_header, custom_file_header_path=config.custom_file_header_path, + custom_formatters=config.custom_formatters, + custom_formatters_kwargs=custom_formatters_kwargs, ) return Exit.OK except InvalidClassNameError as e: diff --git a/datamodel_code_generator/arguments.py b/datamodel_code_generator/arguments.py index 46d4b7e4c..065e12444 100644 --- a/datamodel_code_generator/arguments.py +++ b/datamodel_code_generator/arguments.py @@ -387,10 +387,21 @@ def start_section(self, heading: Optional[str]) -> None: ) base_options.add_argument( '--additional-imports', - help='Custom imports for output (delimited list input)', + help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"', type=str, default=None, ) +base_options.add_argument( + '--custom-formatters', + help='List of modules with custom formatter (delimited list input).', + type=str, + default=None, +) +template_options.add_argument( + '--custom-formatters-kwargs', + help='A file with kwargs for custom formatters.', + type=FileType('rt'), +) # ====================================================================================== # Options specific to OpenAPI input schemas diff --git a/datamodel_code_generator/format.py b/datamodel_code_generator/format.py index e2b84d2b5..b38740d6d 100644 --- a/datamodel_code_generator/format.py +++ b/datamodel_code_generator/format.py @@ -1,6 +1,7 @@ from __future__ import annotations from enum import Enum +from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from warnings import warn @@ -112,6 +113,8 @@ def __init__( wrap_string_literal: Optional[bool] = None, 
skip_string_normalization: bool = True, known_third_party: Optional[List[str]] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, ) -> None: if not settings_path: settings_path = Path().resolve() @@ -167,12 +170,49 @@ def __init__( settings_path=self.settings_path, **self.isort_config_kwargs ) + self.custom_formatters_kwargs = custom_formatters_kwargs or {} + self.custom_formatters = self._check_custom_formatters(custom_formatters) + + def _load_custom_formatter( + self, custom_formatter_import: str + ) -> CustomCodeFormatter: + import_ = import_module(custom_formatter_import) + + if not hasattr(import_, 'CodeFormatter'): + raise NameError( + f'Custom formatter module `{import_.__name__}` must contains object with name Formatter' + ) + + formatter_class = import_.__getattribute__('CodeFormatter') + + if not issubclass(formatter_class, CustomCodeFormatter): + raise TypeError( + f'The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`' + ) + + return formatter_class(formatter_kwargs=self.custom_formatters_kwargs) + + def _check_custom_formatters( + self, custom_formatters: Optional[List[str]] + ) -> List[CustomCodeFormatter]: + if custom_formatters is None: + return [] + + return [ + self._load_custom_formatter(custom_formatter_import) + for custom_formatter_import in custom_formatters + ] + def format_code( self, code: str, ) -> str: code = self.apply_isort(code) code = self.apply_black(code) + + for formatter in self.custom_formatters: + code = formatter.apply(code) + return code def apply_black(self, code: str) -> str: @@ -200,3 +240,11 @@ def apply_isort(self, code: str) -> str: def apply_isort(self, code: str) -> str: return isort.code(code, config=self.isort_config) + + +class CustomCodeFormatter: + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + self.formatter_kwargs = formatter_kwargs + + def apply(self, code: str) -> str: + raise 
NotImplementedError diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index d29da9058..04e44ea22 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -386,6 +386,8 @@ def __init__( keep_model_order: bool = False, use_one_literal_as_default: bool = False, known_third_party: Optional[List[str]] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, ) -> None: self.data_type_manager: DataTypeManager = data_type_manager_type( python_version=target_python_version, @@ -502,6 +504,8 @@ def __init__( self.keep_model_order = keep_model_order self.use_one_literal_as_default = use_one_literal_as_default self.known_third_party = known_third_party + self.custom_formatter = custom_formatters + self.custom_formatters_kwargs = custom_formatters_kwargs @property def iter_source(self) -> Iterator[Source]: @@ -1143,6 +1147,8 @@ def parse( self.wrap_string_literal, skip_string_normalization=not self.use_double_quotes, known_third_party=self.known_third_party, + custom_formatters=self.custom_formatter, + custom_formatters_kwargs=self.custom_formatters_kwargs, ) else: code_formatter = None diff --git a/datamodel_code_generator/parser/graphql.py b/datamodel_code_generator/parser/graphql.py index 57edd7c84..0d4c8226c 100644 --- a/datamodel_code_generator/parser/graphql.py +++ b/datamodel_code_generator/parser/graphql.py @@ -154,6 +154,8 @@ def __init__( keep_model_order: bool = False, use_one_literal_as_default: bool = False, known_third_party: Optional[List[str]] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( source=source, @@ -217,6 +219,8 @@ def __init__( capitalise_enum_members=capitalise_enum_members, keep_model_order=keep_model_order, known_third_party=known_third_party, + custom_formatters=custom_formatters, + 
custom_formatters_kwargs=custom_formatters_kwargs, ) self.data_model_scalar_type = data_model_scalar_type diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 20f4d9fcc..1cbcbff42 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -422,6 +422,8 @@ def __init__( capitalise_enum_members: bool = False, keep_model_order: bool = False, known_third_party: Optional[List[str]] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( source=source, @@ -485,6 +487,8 @@ def __init__( capitalise_enum_members=capitalise_enum_members, keep_model_order=keep_model_order, known_third_party=known_third_party, + custom_formatters=custom_formatters, + custom_formatters_kwargs=custom_formatters_kwargs, ) self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict() diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index c2ee816d2..3bc6c2061 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -218,6 +218,8 @@ def __init__( capitalise_enum_members: bool = False, keep_model_order: bool = False, known_third_party: Optional[List[str]] = None, + custom_formatters: Optional[List[str]] = None, + custom_formatters_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( source=source, @@ -281,6 +283,8 @@ def __init__( capitalise_enum_members=capitalise_enum_members, keep_model_order=keep_model_order, known_third_party=known_third_party, + custom_formatters=custom_formatters, + custom_formatters_kwargs=custom_formatters_kwargs, ) self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [ OpenAPIScope.Schemas diff --git a/docs/custom-formatters.md b/docs/custom-formatters.md new file mode 100644 index 000000000..1768a8f35 --- /dev/null +++ 
b/docs/custom-formatters.md @@ -0,0 +1,23 @@ +# Custom Code Formatters + +Custom code formatters are a new feature of `datamodel-code-generator`. + +## Usage +To use the `--custom-formatters` option, you'll need to pass the module with your formatter. For example: + +**your_module.py** +```python +from datamodel_code_generator.format import CustomCodeFormatter + +class CodeFormatter(CustomCodeFormatter): + def apply(self, code: str) -> str: + # processed code + return ... + +``` + +and run the following command: + +```sh +$ datamodel-codegen --input {your_input_file} --output {your_output_file} --custom-formatters "{path_to_your_module}.your_module" +``` diff --git a/docs/index.md b/docs/index.md index 600a4469c..158c73360 100644 --- a/docs/index.md +++ b/docs/index.md @@ -435,8 +435,11 @@ Template customization: Wrap string literal by using black `experimental- string-processing` option (require black 20.8b0 or later) - --additional-imports Custom imports for output (delimited list input) - + --additional-imports Custom imports for output (delimited list input). + For example "datetime.date,datetime.datetime" + --custom-formatters List of modules with custom formatter (delimited list input). + --custom-formatters-kwargs A file with kwargs for custom formatters. + OpenAPI-only options: --openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...] 
Scopes of OpenAPI model generation (default: schemas) diff --git a/docs/supported-data-types.md b/docs/supported-data-types.md index 46f4b1fa9..adf9200ea 100644 --- a/docs/supported-data-types.md +++ b/docs/supported-data-types.md @@ -6,6 +6,7 @@ This code generator supports the following input formats: - JSON Schema ([JSON Schema Core](http://json-schema.org/draft/2019-09/json-schema-validation.html) /[JSON Schema Validation](http://json-schema.org/draft/2019-09/json-schema-validation.html)) - JSON/YAML Data (it will be converted to JSON Schema) - Python dictionary (it will be converted to JSON Schema) +- GraphQL schema ([GraphQL Schemas and Types](https://graphql.org/learn/schema/)) ## Implemented data types and features diff --git a/mkdocs.yml b/mkdocs.yml index 4ff987bc5..34547e10a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,6 +26,7 @@ nav: - Generate from JSON Data: jsondata.md - Generate from GraphQL Schema: graphql.md - Custom template: custom_template.md + - Custom formatters: custom-formatters.md - Using as module: using_as_module.md - Formatting: formatting.md - Field Constraints: field-constraints.md diff --git a/tests/data/expected/main/main_graphql_custom_formatters/output.py b/tests/data/expected/main/main_graphql_custom_formatters/output.py new file mode 100644 index 000000000..10f9c4950 --- /dev/null +++ b/tests/data/expected/main/main_graphql_custom_formatters/output.py @@ -0,0 +1,37 @@ +# generated by datamodel-codegen: +# filename: custom-scalar-types.graphql +# timestamp: 2019-07-26T00:00:00+00:00 + +# a comment +from __future__ import annotations + +from typing import Optional, TypeAlias + +from pydantic import BaseModel, Field +from typing_extensions import Literal + +Boolean: TypeAlias = bool +""" +The `Boolean` scalar type represents `true` or `false`. +""" + + +ID: TypeAlias = str +""" +The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. 
The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID. +""" + + +Long: TypeAlias = str + + +String: TypeAlias = str +""" +The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text. +""" + + +class A(BaseModel): + duration: Long + id: ID + typename__: Optional[Literal['A']] = Field('A', alias='__typename') diff --git a/tests/data/python/custom_formatters/add_comment.py b/tests/data/python/custom_formatters/add_comment.py new file mode 100644 index 000000000..bf539928c --- /dev/null +++ b/tests/data/python/custom_formatters/add_comment.py @@ -0,0 +1,7 @@ +from datamodel_code_generator.format import CustomCodeFormatter + + +class CodeFormatter(CustomCodeFormatter): + """Simple correct formatter. Adding a comment to top of code.""" + def apply(self, code: str) -> str: + return f'# a comment\n{code}' diff --git a/tests/data/python/custom_formatters/add_license.py b/tests/data/python/custom_formatters/add_license.py new file mode 100644 index 000000000..710b3fe84 --- /dev/null +++ b/tests/data/python/custom_formatters/add_license.py @@ -0,0 +1,24 @@ +from typing import Any, Dict +from pathlib import Path + +from datamodel_code_generator.format import CustomCodeFormatter + + +class CodeFormatter(CustomCodeFormatter): + """Add a license to file from license file path.""" + + def __init__(self, formatter_kwargs: Dict[str, Any]) -> None: + super().__init__(formatter_kwargs) + + if 'license_file' not in formatter_kwargs: + raise ValueError() + + license_file_path = Path(formatter_kwargs['license_file']).resolve() + + with license_file_path.open("r") as f: + license_file = f.read() + + self.license_header = '\n'.join([f'# {line}' for line in license_file.split('\n')]) + + def apply(self, 
code: str) -> str: + return f'{self.license_header}\n{code}' diff --git a/tests/data/python/custom_formatters/license_example.txt b/tests/data/python/custom_formatters/license_example.txt new file mode 100644 index 000000000..8a12fbe07 --- /dev/null +++ b/tests/data/python/custom_formatters/license_example.txt @@ -0,0 +1,3 @@ +MIT License + +Copyright (c) 2023 Blah-blah diff --git a/tests/data/python/custom_formatters/not_subclass.py b/tests/data/python/custom_formatters/not_subclass.py new file mode 100644 index 000000000..231e2fb04 --- /dev/null +++ b/tests/data/python/custom_formatters/not_subclass.py @@ -0,0 +1,3 @@ +class CodeFormatter: + """Invalid formatter: is not subclass of `datamodel_code_generator.format.CustomCodeFormatter`.""" + pass diff --git a/tests/data/python/custom_formatters/wrong.py b/tests/data/python/custom_formatters/wrong.py new file mode 100644 index 000000000..980e033fd --- /dev/null +++ b/tests/data/python/custom_formatters/wrong.py @@ -0,0 +1,7 @@ +from datamodel_code_generator.format import CustomCodeFormatter + + +class WrongFormatterName(CustomCodeFormatter): + """Invalid formatter: correct name is CodeFormatter.""" + def apply(self, code: str) -> str: + return f'# a comment\n{code}' diff --git a/tests/test_format.py b/tests/test_format.py index 7054a11b8..4274f093d 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -1,9 +1,20 @@ import sys +from pathlib import Path import pytest from datamodel_code_generator.format import CodeFormatter, PythonVersion +EXAMPLE_LICENSE_FILE = str( + Path(__file__).parent / 'data/python/custom_formatters/license_example.txt' +) + +UN_EXIST_FORMATTER = 'tests.data.python.custom_formatters.un_exist' +WRONG_FORMATTER = 'tests.data.python.custom_formatters.wrong' +NOT_SUBCLASS_FORMATTER = 'tests.data.python.custom_formatters.not_subclass' +ADD_COMMENT_FORMATTER = 'tests.data.python.custom_formatters.add_comment' +ADD_LICENSE_FORMATTER = 'tests.data.python.custom_formatters.add_license' + def 
test_python_version(): """Ensure that the python version used for the tests is properly listed""" @@ -28,3 +39,84 @@ def test_format_code_with_skip_string_normalization( formatted_code = formatter.format_code("a = 'b'") assert formatted_code == expected_output + '\n' + + +def test_format_code_un_exist_custom_formatter(): + with pytest.raises(ModuleNotFoundError): + _ = CodeFormatter( + PythonVersion.PY_37, + custom_formatters=[UN_EXIST_FORMATTER], + ) + + +def test_format_code_invalid_formatter_name(): + with pytest.raises(NameError): + _ = CodeFormatter( + PythonVersion.PY_37, + custom_formatters=[WRONG_FORMATTER], + ) + + +def test_format_code_is_not_subclass(): + with pytest.raises(TypeError): + _ = CodeFormatter( + PythonVersion.PY_37, + custom_formatters=[NOT_SUBCLASS_FORMATTER], + ) + + +def test_format_code_with_custom_formatter_without_kwargs(): + formatter = CodeFormatter( + PythonVersion.PY_37, + custom_formatters=[ADD_COMMENT_FORMATTER], + ) + + formatted_code = formatter.format_code('x = 1\ny = 2') + + assert formatted_code == '# a comment\nx = 1\ny = 2' + '\n' + + +def test_format_code_with_custom_formatter_with_kwargs(): + formatter = CodeFormatter( + PythonVersion.PY_37, + custom_formatters=[ADD_LICENSE_FORMATTER], + custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE}, + ) + + formatted_code = formatter.format_code('x = 1\ny = 2') + + assert ( + formatted_code + == """# MIT License +# +# Copyright (c) 2023 Blah-blah +# +x = 1 +y = 2 +""" + ) + + +def test_format_code_with_two_custom_formatters(): + formatter = CodeFormatter( + PythonVersion.PY_37, + custom_formatters=[ + ADD_COMMENT_FORMATTER, + ADD_LICENSE_FORMATTER, + ], + custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE}, + ) + + formatted_code = formatter.format_code('x = 1\ny = 2') + + assert ( + formatted_code + == """# MIT License +# +# Copyright (c) 2023 Blah-blah +# +# a comment +x = 1 +y = 2 +""" + ) diff --git a/tests/test_main.py b/tests/test_main.py index 
25b161c80..c7e0318d6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -6335,3 +6335,28 @@ def test_main_graphql_additional_imports_isort_5(): / 'output_isort5.py' ).read_text() ) + + +@freeze_time('2019-07-26') +def test_main_graphql_custom_formatters(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(GRAPHQL_DATA_PATH / 'custom-scalar-types.graphql'), + '--output', + str(output_file), + '--input-file-type', + 'graphql', + '--custom-formatters', + 'tests.data.python.custom_formatters.add_comment', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH / 'main_graphql_custom_formatters' / 'output.py' + ).read_text() + )