From eddfeaf75dd3af17570bf8fae77ea5140e1589a3 Mon Sep 17 00:00:00 2001 From: Chris Riccomini <criccomini@apache.org> Date: Wed, 11 Dec 2024 22:39:06 -0800 Subject: [PATCH] Add support for scalar option generation --- proto_schema_parser/generator.py | 8 +++++-- tests/test_generator.py | 41 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/proto_schema_parser/generator.py b/proto_schema_parser/generator.py index 46dad27..d2d6f4c 100644 --- a/proto_schema_parser/generator.py +++ b/proto_schema_parser/generator.py @@ -1,5 +1,5 @@ import itertools -from typing import List +from typing import List, cast from proto_schema_parser import ast @@ -118,9 +118,13 @@ def _generate_field(self, field: ast.Field, indent_level: int = 0) -> str: cardinality = f"{field.cardinality.value.lower()} " options = "" + print(field.options) if field.options: options = " [" - options += ", ".join(f'{opt.name} = "{opt.value}"' for opt in field.options) + options += ", ".join( + f"{opt.name} = {self._generate_scalar(cast(ast.ScalarValue, opt.value))}" + for opt in field.options + ) options += "]" return f"{' ' * indent_level}{cardinality}{field.type} {field.name} = {field.number}{options};" diff --git a/tests/test_generator.py b/tests/test_generator.py index 0c2ab82..1bc547f 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1119,3 +1119,44 @@ def test_generate_option_with_multiple_message_literals(): ) assert result == expected + + +def test_generate_field_with_options_for_scalar_types(): + generator = Generator() + test_cases = [ + # Numeric types + ("double", 3.14, "double field_double = 1 [default = 3.14];"), + ("float", -1.23, "float field_float = 2 [default = -1.23];"), + ("int32", -42, "int32 field_int32 = 3 [default = -42];"), + ("int64", 9876543210, "int64 field_int64 = 4 [default = 9876543210];"), + ("uint32", 42, "uint32 field_uint32 = 5 [default = 42];"), + ("uint64", 1234567890, "uint64 field_uint64 = 6 [default = 1234567890];"), + ("sint32", -123, "sint32 field_sint32 = 7 [default = -123];"), + ("sint64", -9876543210, "sint64 field_sint64 = 8 [default = -9876543210];"), + ("fixed32", 456, "fixed32 field_fixed32 = 9 [default = 456];"), + ("fixed64", 1234567890, "fixed64 field_fixed64 = 10 [default = 1234567890];"), + ("sfixed32", -789, "sfixed32 field_sfixed32 = 11 [default = -789];"), + ( + "sfixed64", + -1234567890, + "sfixed64 field_sfixed64 = 12 [default = -1234567890];", + ), + # Boolean type + ("bool", True, "bool field_bool = 13 [default = true];"), + # String types + ( + "string", + "test string", + 'string field_string = 14 [default = "test string"];', + ), + ] + + for type_name, default_value, expected in test_cases: + field = ast.Field( + name=f"field_{type_name}", + type=type_name, + number=test_cases.index((type_name, default_value, expected)) + 1, + options=[ast.Option(name="default", value=default_value)], + ) + result = generator._generate_field(field) + assert result == expected