Skip to content

Commit

Permalink
Refactor generator
Browse files Browse the repository at this point in the history
  • Loading branch information
treiher committed Jun 7, 2022
1 parent 1dd1974 commit 168abf6
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 93 deletions.
199 changes: 116 additions & 83 deletions rflx/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
log = logging.getLogger(__name__)


class Generator: # pylint: disable = too-many-instance-attributes
class Generator:
def __init__(
self,
prefix: str = "",
Expand All @@ -127,12 +127,7 @@ def __init__(
self._reproducible = reproducible
self._debug = debug
self._ignore_unsupported_checksum = ignore_unsupported_checksum
self._parser = ParserGenerator(self._prefix)
self._serializer = SerializerGenerator(self._prefix)

self._executor = ProcessPoolExecutor(max_workers=workers)
self._units: ty.Dict[ID, Unit] = {}

self._template_dir = Path(pkg_resources.resource_filename(*const.TEMPLATE_DIR))
assert self._template_dir.is_dir(), "template directory not found"

Expand All @@ -144,13 +139,17 @@ def generate(
library_files: bool = True,
top_level_package: bool = True,
) -> None:
self._generate(model, integration)
self._write_files(directory, library_files, top_level_package)
units = self._generate(model, integration)
self._write_files(units, directory, library_files, top_level_package)

def _write_files(
self, directory: Path, library_files: bool = True, top_level_package: bool = True
self,
units: dict[ID, Unit],
directory: Path,
library_files: bool = True,
top_level_package: bool = True,
) -> None:
self._write_units(directory)
self._write_units(units, directory)
if library_files:
self._write_library_files(directory)
if top_level_package:
Expand Down Expand Up @@ -211,25 +210,28 @@ def _write_top_level_package(self, directory: Path) -> None:
self._license_header() + f"package {self._prefix} is\n\nend {self._prefix};",
)

def _write_units(self, directory: Path) -> None:
for unit in self._units.values():
def _write_units(self, units: dict[ID, Unit], directory: Path) -> None:
for unit in units.values():
create_file(directory / Path(unit.name + ".ads"), self._license_header() + unit.ads)

if unit.adb:
create_file(directory / Path(unit.name + ".adb"), self._license_header() + unit.adb)

def _generate(self, model: Model, integration: Integration) -> None:
def _generate(self, model: Model, integration: Integration) -> dict[ID, Unit]:
units: dict[ID, Unit] = {}

for t in model.types:
if t.package in [BUILTINS_PACKAGE, INTERNAL_PACKAGE]:
continue

log.info("Generating %s", t.identifier)

if t.package not in self._units:
self._create_unit(ID(t.package), terminating=False)
if t.package not in units:
unit = self._create_unit(ID(t.package), terminating=False)
units[ID(t.package)] = unit

if isinstance(t, (Scalar, Composite)):
self._create_type(t, ID(t.package))
units.update(self._create_type(t, ID(t.package), units))

elif isinstance(t, Message):
# ISSUE: Componolit/RecordFlux#276
Expand All @@ -245,31 +247,38 @@ def _generate(self, model: Model, integration: Integration) -> None:
"unsupported checksum ignored", Subsystem.GENERATOR, location=c.location
)

self._create_message(t)
units.update(self._create_message(t))

elif isinstance(t, Refinement):
self._create_refinement(t)
units.update(self._create_refinement(t, units))

else:
assert False, f'unexpected type "{type(t).__name__}"'

for s in model.sessions:
log.info("Generating %s", s.identifier)

if s.package not in self._units:
self._create_unit(ID(s.package), terminating=False)
if s.package not in units:
unit = self._create_unit(ID(s.package), terminating=False)
units[ID(s.package)] = unit

units.update(self._create_session(s, integration))

self._create_session(s, integration)
return units

def _create_session(self, session: Session, integration: Integration) -> None:
def _create_session(self, session: Session, integration: Integration) -> dict[ID, Unit]:
units: dict[ID, Unit] = {}
allocator_generator = AllocatorGenerator(session, integration, self._prefix)

if allocator_generator.required:
unit = self._create_unit(
allocator_generator.unit_identifier,
allocator_generator.declaration_context,
allocator_generator.body_context,
)
unit += allocator_generator.unit_part
units[allocator_generator.unit_identifier] = unit

session_generator = SessionGenerator(
session, allocator_generator, self._prefix, debug=self._debug
)
Expand All @@ -281,6 +290,9 @@ def _create_session(self, session: Session, integration: Integration) -> None:
terminating=False,
)
unit += session_generator.unit_part
units[session_generator.unit_identifier] = unit

return units

def _create_unit( # pylint: disable = too-many-arguments
self,
Expand Down Expand Up @@ -311,28 +323,27 @@ def _create_unit( # pylint: disable = too-many-arguments
[*configuration_pragmas, *const.CONFIGURATION_PRAGMAS, *body_context],
PackageBody(self._prefix * identifier, aspects=[SparkMode()]),
)
self._units[identifier] = unit

return unit

@staticmethod
def _create_instantiation_unit(
self,
identifier: ID,
context: ty.List[ContextItem],
instantiation: GenericPackageInstantiation,
) -> InstantiationUnit:
for p in reversed(const.CONFIGURATION_PRAGMAS):
context.insert(0, p)

unit = InstantiationUnit(context, instantiation)
self._units[identifier] = unit

return unit

# pylint: disable = too-many-branches
def _create_message(self, message: Message) -> None:
# pylint: disable = too-many-branches, too-many-locals
def _create_message(self, message: Message) -> dict[ID, Unit]:
units: dict[ID, Unit] = {}

if not message.fields:
return
return units

context: ty.List[ContextItem] = []

Expand Down Expand Up @@ -364,6 +375,7 @@ def _create_message(self, message: Message) -> None:
context.append(WithClause(self._prefix * ID(field_type.identifier)))

unit = self._create_unit(ID(message.identifier), context)
units[ID(message.identifier)] = unit

scalar_fields = {}
composite_fields = []
Expand All @@ -386,11 +398,14 @@ def _create_message(self, message: Message) -> None:
else:
fields_with_explicit_size.append(f)

parser_generator = ParserGenerator(self._prefix)
serializer_generator = SerializerGenerator(self._prefix)

futures = [
self._executor.submit(
message_generator.create_use_type_clause,
composite_fields,
self._serializer.requires_set_procedure(message),
serializer_generator.requires_set_procedure(message),
),
self._executor.submit(message_generator.create_allow_unevaluated_use_of_old),
self._executor.submit(message_generator.create_field_type, message),
Expand Down Expand Up @@ -482,49 +497,51 @@ def _create_message(self, message: Message) -> None:
else []
),
self._executor.submit(
self._parser.create_get_function, message, scalar_fields, composite_fields
parser_generator.create_get_function, message, scalar_fields, composite_fields
),
self._executor.submit(
parser_generator.create_verify_procedure, message, scalar_fields, composite_fields
),
self._executor.submit(parser_generator.create_verify_message_procedure, message),
self._executor.submit(parser_generator.create_present_function),
self._executor.submit(parser_generator.create_structural_valid_function),
self._executor.submit(parser_generator.create_valid_function),
self._executor.submit(parser_generator.create_incomplete_function),
self._executor.submit(parser_generator.create_invalid_function),
self._executor.submit(
self._parser.create_verify_procedure, message, scalar_fields, composite_fields
parser_generator.create_structural_valid_message_function, message
),
self._executor.submit(self._parser.create_verify_message_procedure, message),
self._executor.submit(self._parser.create_present_function),
self._executor.submit(self._parser.create_structural_valid_function),
self._executor.submit(self._parser.create_valid_function),
self._executor.submit(self._parser.create_incomplete_function),
self._executor.submit(self._parser.create_invalid_function),
self._executor.submit(self._parser.create_structural_valid_message_function, message),
self._executor.submit(self._parser.create_valid_message_function, message),
self._executor.submit(self._parser.create_incomplete_message_function),
self._executor.submit(self._parser.create_scalar_getter_functions, scalar_fields),
self._executor.submit(self._parser.create_opaque_getter_functions, opaque_fields),
self._executor.submit(self._parser.create_opaque_getter_procedures, opaque_fields),
self._executor.submit(parser_generator.create_valid_message_function, message),
self._executor.submit(parser_generator.create_incomplete_message_function),
self._executor.submit(parser_generator.create_scalar_getter_functions, scalar_fields),
self._executor.submit(parser_generator.create_opaque_getter_functions, opaque_fields),
self._executor.submit(parser_generator.create_opaque_getter_procedures, opaque_fields),
self._executor.submit(
self._parser.create_generic_opaque_getter_procedures, opaque_fields
parser_generator.create_generic_opaque_getter_procedures, opaque_fields
),
self._executor.submit(self._serializer.create_valid_size_function, message),
self._executor.submit(self._serializer.create_valid_length_function),
self._executor.submit(serializer_generator.create_valid_size_function, message),
self._executor.submit(serializer_generator.create_valid_length_function),
self._executor.submit(
self._serializer.create_set_procedure, message, scalar_fields, composite_fields
serializer_generator.create_set_procedure, message, scalar_fields, composite_fields
),
self._executor.submit(
self._serializer.create_scalar_setter_procedures, message, scalar_fields
serializer_generator.create_scalar_setter_procedures, message, scalar_fields
),
self._executor.submit(
self._serializer.create_composite_setter_empty_procedures, message
serializer_generator.create_composite_setter_empty_procedures, message
),
self._executor.submit(
self._serializer.create_sequence_setter_procedures, message, sequence_fields
serializer_generator.create_sequence_setter_procedures, message, sequence_fields
),
self._executor.submit(
self._serializer.create_composite_initialize_procedures,
serializer_generator.create_composite_initialize_procedures,
message,
fields_with_explicit_size,
fields_with_implicit_size,
),
self._executor.submit(self._serializer.create_opaque_setter_procedures, message),
self._executor.submit(serializer_generator.create_opaque_setter_procedures, message),
self._executor.submit(
self._serializer.create_generic_opaque_setter_procedures, message
serializer_generator.create_generic_opaque_setter_procedures, message
),
self._executor.submit(
message_generator.create_switch_procedures, message, sequence_fields, self._prefix
Expand All @@ -544,6 +561,8 @@ def _create_message(self, message: Message) -> None:
for future in futures:
unit += future.result()

return units

@staticmethod
def _requires_composite_field_function(
message: Message,
Expand All @@ -557,17 +576,21 @@ def _requires_composite_field_function(
or sequence_fields
)

def _create_refinement(self, refinement: Refinement) -> None:
unit_name = refinement.package * const.REFINEMENT_PACKAGE
def _create_refinement(
self, refinement: Refinement, units: ty.Mapping[ID, Unit]
) -> dict[ID, Unit]:
result: dict[ID, Unit] = {}
identifier = refinement.package * const.REFINEMENT_PACKAGE
null_sdu = not refinement.sdu.fields

if unit_name in self._units:
unit = self._units[unit_name]
if identifier in units:
unit = units[identifier]
else:
unit = self._create_unit(
unit_name,
identifier,
[WithClause(self._prefix * const.TYPES_PACKAGE)] if not null_sdu else [],
)
result[identifier] = unit

assert isinstance(unit, PackageUnit), "unexpected unit type"

Expand Down Expand Up @@ -627,10 +650,16 @@ def _create_refinement(self, refinement: Refinement) -> None:
unit += self._create_switch_procedure(refinement, condition_fields)
unit += self._create_copy_refined_field_procedure(refinement, condition_fields)

def _create_type(self, field_type: Type, message_package: ID) -> None:
return result

def _create_type(
self, field_type: Type, message_package: ID, units: ty.Mapping[ID, Unit]
) -> dict[ID, Unit]:
assert field_type.package != BUILTINS_PACKAGE

unit = self._units[message_package]
result: dict[ID, Unit] = {}

unit = units[message_package]

assert isinstance(unit, PackageUnit)

Expand All @@ -647,32 +676,36 @@ def _create_type(self, field_type: Type, message_package: ID) -> None:
unit += UnitPart(enumeration_types(field_type))
unit += self._enumeration_functions(field_type)
elif isinstance(field_type, Sequence):
self._create_sequence_unit(field_type)
result.update(self._create_sequence(field_type))
else:
assert False, f'unexpected type "{type(field_type).__name__}"'

def _create_sequence_unit(self, sequence_type: Sequence) -> None:
return result

def _create_sequence(self, sequence_type: Sequence) -> dict[ID, Unit]:
context, package = common.create_sequence_instantiation(sequence_type, self._prefix)
self._create_instantiation_unit(
package.identifier,
[
Pragma("SPARK_Mode"),
*context,
# WORKAROUND: Componolit/Workarounds#33
# A compiler error about a non-visible declaration of RFLX_Types inside the
# generic sequence package is prevented by adding a with-clause for this package.
Pragma(
"Warnings",
[Variable("Off"), String('unit "*RFLX_Types" is not referenced')],
),
WithClause(self._prefix * const.TYPES_PACKAGE),
Pragma(
"Warnings",
[Variable("On"), String('unit "*RFLX_Types" is not referenced')],
),
],
package,
)
return {
package.identifier: self._create_instantiation_unit(
[
Pragma("SPARK_Mode"),
*context,
# ISSUE: Componolit/Workarounds#33
# A compiler error about a non-visible declaration of RFLX_Types inside the
# generic sequence package is prevented by adding a with-clause for this
# package.
Pragma(
"Warnings",
[Variable("Off"), String('unit "*RFLX_Types" is not referenced')],
),
WithClause(self._prefix * const.TYPES_PACKAGE),
Pragma(
"Warnings",
[Variable("On"), String('unit "*RFLX_Types" is not referenced')],
),
],
package,
)
}

def _integer_functions(self, integer: Integer) -> UnitPart:
specification: ty.List[Declaration] = []
Expand Down
Loading

0 comments on commit 168abf6

Please sign in to comment.