diff --git a/pbj-core/gradle.properties b/pbj-core/gradle.properties index feec7917..b0b8aa7b 100644 --- a/pbj-core/gradle.properties +++ b/pbj-core/gradle.properties @@ -1,5 +1,5 @@ # Version number -version=0.8.7-SNAPSHOT +version=0.9.0-SNAPSHOT # Need increased heap for running Gradle itself, or SonarQube will run the JVM out of metaspace org.gradle.jvmargs=-Xmx2048m diff --git a/pbj-core/pbj-compiler/src/main/antlr/com/hedera/hashgraph/protoparser/grammar/Protobuf3.g4 b/pbj-core/pbj-compiler/src/main/antlr/com/hedera/hashgraph/protoparser/grammar/Protobuf3.g4 index c7a0c359..1f9fbc74 100644 --- a/pbj-core/pbj-compiler/src/main/antlr/com/hedera/hashgraph/protoparser/grammar/Protobuf3.g4 +++ b/pbj-core/pbj-compiler/src/main/antlr/com/hedera/hashgraph/protoparser/grammar/Protobuf3.g4 @@ -91,7 +91,7 @@ oneofField // Map field mapField - : MAP LT keyType COMMA type_ GT mapName + : docComment MAP LT keyType COMMA type_ GT mapName EQ fieldNumber ( LB fieldOptions RB )? SEMI ; keyType diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java index 62c320f3..23006654 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Common.java @@ -254,6 +254,8 @@ public static String getFieldsHashCode(final List fields, String generate result = 31 * result + Integer.hashCode($fieldName.protoOrdinal()); } """).replace("$fieldName", f.nameCamelFirstLower()); + } else if (f.type() == Field.FieldType.MAP) { + generatedCodeSoFar += getMapHashCodeGeneration(generatedCodeSoFar, f); } else if (f.type() == Field.FieldType.STRING || f.parent() == null) { // process sub message generatedCodeSoFar += ( @@ -350,6 +352,33 @@ private static String getRepeatedHashCodeGeneration(String generatedCodeSoFar, F return generatedCodeSoFar; } + /** + * Get the hashcode codegen for a map field. + * @param generatedCodeSoFar The string that the codegen is generated into. + * @param f The field for which to generate the hash code. + * @return Updated codegen string. + */ + @NonNull + private static String getMapHashCodeGeneration(String generatedCodeSoFar, final Field f) { + generatedCodeSoFar += ( + """ + for (Object k : ((PbjMap) $fieldName).getSortedKeys()) { + if (k != null) { + result = 31 * result + k.hashCode(); + } else { + result = 31 * result; + } + Object v = $fieldName.get(k); + if (v != null) { + result = 31 * result + v.hashCode(); + } else { + result = 31 * result; + } + } + """).replace("$fieldName", f.nameCamelFirstLower()); + return generatedCodeSoFar; + } + /** * Recursively calculates `equals` statement for a message fields. * @@ -417,6 +446,7 @@ else if (f.repeated()) { } else if (f.type() == Field.FieldType.STRING || f.type() == Field.FieldType.BYTES || f.type() == Field.FieldType.ENUM || + f.type() == Field.FieldType.MAP || f.parent() == null /* Process a sub-message */) { generatedCodeSoFar += ( """ diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/ContextualLookupHelper.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/ContextualLookupHelper.java index 2a16f3bd..849f53fd 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/ContextualLookupHelper.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/ContextualLookupHelper.java @@ -91,6 +91,17 @@ public String getPackageFieldMessageType(final FileType fileType, final FieldCon return lookupHelper.getPackage(srcProtoFileContext, fileType, fieldContext.type_().messageType()); } + /** + * Get the Java package a class should be generated into for a given typeContext and file type. + * + * @param fileType The type of file we want the package for + * @param typeContext The field to get package for message type for + * @return java package to put model class in + */ + public String getPackageFieldMessageType(final FileType fileType, final Type_Context typeContext) { + return lookupHelper.getPackage(srcProtoFileContext, fileType, typeContext.messageType()); + } + /** * Get the PBJ Java package a class should be generated into for a given fieldContext and file type. * diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java index 87b29c7a..1dcc3ec8 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/Field.java @@ -189,49 +189,73 @@ default OneOfField parent() { return null; } + /** + * Extract the name of the Java model class for a message type, + * or null if the type is not a message. + */ + static String extractMessageTypeName(final Protobuf3Parser.Type_Context typeContext) { + return typeContext.messageType() == null ? null : typeContext.messageType().messageName().getText(); + } + + /** + * Extract the name of the Java package for a given FileType for a message type, + * or null if the type is not a message. + */ + static String extractMessageTypePackage( + final Protobuf3Parser.Type_Context typeContext, + final FileType fileType, + final ContextualLookupHelper lookupHelper) { + return typeContext.messageType() == null || typeContext.messageType().messageName().getText() == null ? null : + lookupHelper.getPackageFieldMessageType(fileType, typeContext); + } + /** * Field type enum for use in field classes */ enum FieldType { /** Protobuf message field type */ - MESSAGE("Object", "null", TYPE_LENGTH_DELIMITED), + MESSAGE("Object", "Object", "null", TYPE_LENGTH_DELIMITED), /** Protobuf enum(unsigned varint encoded int of ordinal) field type */ - ENUM("int", "null", TYPE_VARINT), + ENUM("int", "Integer", "null", TYPE_VARINT), /** Protobuf int32(signed varint encoded int) field type */ - INT32("int", "0", TYPE_VARINT), + INT32("int", "Integer", "0", TYPE_VARINT), /** Protobuf uint32(unsigned varint encoded int) field type */ - UINT32("int", "0", TYPE_VARINT), + UINT32("int", "Integer", "0", TYPE_VARINT), /** Protobuf sint32(signed zigzag varint encoded int) field type */ - SINT32("int", "0", TYPE_VARINT), + SINT32("int", "Integer", "0", TYPE_VARINT), /** Protobuf int64(signed varint encoded long) field type */ - INT64("long", "0", TYPE_VARINT), + INT64("long", "Long", "0", TYPE_VARINT), /** Protobuf uint64(unsigned varint encoded long) field type */ - UINT64("long", "0", TYPE_VARINT), + UINT64("long", "Long", "0", TYPE_VARINT), /** Protobuf sint64(signed zigzag varint encoded long) field type */ - SINT64("long", "0", TYPE_VARINT), + SINT64("long", "Long", "0", TYPE_VARINT), /** Protobuf float field type */ - FLOAT("float", "0", TYPE_FIXED32), + FLOAT("float", "Float", "0", TYPE_FIXED32), /** Protobuf fixed int32(fixed encoding int) field type */ - FIXED32("int", "0", TYPE_FIXED32), + FIXED32("int", "Integer", "0", TYPE_FIXED32), /** Protobuf sfixed int32(signed fixed encoding int) field type */ - SFIXED32("int", "0", TYPE_FIXED32), + SFIXED32("int", "Integer", "0", TYPE_FIXED32), /** Protobuf double field type */ - DOUBLE("double", "0", TYPE_FIXED64), + DOUBLE("double", "Double", "0", TYPE_FIXED64), /** Protobuf sfixed64(fixed encoding long) field type */ - FIXED64("long", "0", TYPE_FIXED64), + FIXED64("long", "Long", "0", TYPE_FIXED64), /** Protobuf sfixed64(signed fixed encoding long) field type */ - SFIXED64("long", "0", TYPE_FIXED64), + SFIXED64("long", "Long", "0", TYPE_FIXED64), /** Protobuf string field type */ - STRING("String", "\"\"", TYPE_LENGTH_DELIMITED), + STRING("String", "String", "\"\"", TYPE_LENGTH_DELIMITED), /** Protobuf bool(boolean) field type */ - BOOL("boolean", "false", TYPE_VARINT), + BOOL("boolean", "Boolean", "false", TYPE_VARINT), /** Protobuf bytes field type */ - BYTES("Bytes", "Bytes.EMPTY", TYPE_LENGTH_DELIMITED), + BYTES("Bytes", "Bytes", "Bytes.EMPTY", TYPE_LENGTH_DELIMITED), /** Protobuf oneof field type, this is not a true field type in protobuf. Needed here for a few edge cases */ - ONE_OF("OneOf", "null", 0 );// BAD TYPE + ONE_OF("OneOf", "OneOf", "null", 0 ),// BAD TYPE + // On the wire, a map is a repeated Message {key, value}, sorted in the natural order of keys for determinism. + MAP("Map", "Map", "Collections.EMPTY_MAP", TYPE_LENGTH_DELIMITED ); /** The type of field type in Java code */ public final String javaType; + /** The type of boxed field type in Java code */ + public final String boxedType; /** The field type default value in Java code */ public final String javaDefault; /** The protobuf wire type for field type */ @@ -241,11 +265,13 @@ enum FieldType { * Construct a new FieldType enum * * @param javaType The type of field type in Java code + * @param boxedType The boxed type of the field type, e.g. Integer for an int field. * @param javaDefault The field type default value in Java code * @param wireType The protobuf wire type for field type */ - FieldType(String javaType, final String javaDefault, int wireType) { + FieldType(String javaType, final String boxedType, final String javaDefault, int wireType) { this.javaType = javaType; + this.boxedType = boxedType; this.javaDefault = javaDefault; this.wireType = wireType; } @@ -337,5 +363,42 @@ static FieldType of(Protobuf3Parser.Type_Context typeContext, final ContextualL throw new IllegalArgumentException("Unknown field type: "+typeContext); } } + + /** + * Get the field type for a given map key type parser context + * + * @param typeContext The parser context to get field type for + * @param lookupHelper Lookup helper with global context + * @return The field type enum for parser context + */ + static FieldType of(Protobuf3Parser.KeyTypeContext typeContext, final ContextualLookupHelper lookupHelper) { + if (typeContext.INT32() != null) { + return FieldType.INT32; + } else if (typeContext.UINT32() != null) { + return FieldType.UINT32; + } else if (typeContext.SINT32() != null) { + return FieldType.SINT32; + } else if (typeContext.INT64() != null) { + return FieldType.INT64; + } else if (typeContext.UINT64() != null) { + return FieldType.UINT64; + } else if (typeContext.SINT64() != null) { + return FieldType.SINT64; + } else if (typeContext.FIXED32() != null) { + return FieldType.FIXED32; + } else if (typeContext.SFIXED32() != null) { + return FieldType.SFIXED32; + } else if (typeContext.FIXED64() != null) { + return FieldType.FIXED64; + } else if (typeContext.SFIXED64() != null) { + return FieldType.SFIXED64; + } else if (typeContext.STRING() != null) { + return FieldType.STRING; + } else if (typeContext.BOOL() != null) { + return FieldType.BOOL; + } else { + throw new IllegalArgumentException("Unknown map key type: " + typeContext); + } + } } } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java new file mode 100644 index 00000000..e092a918 --- /dev/null +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/MapField.java @@ -0,0 +1,150 @@ +package com.hedera.pbj.compiler.impl; + +import java.util.Set; +import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; +import static com.hedera.pbj.compiler.impl.SingleField.getDeprecatedOption; + +/** + * A field of type map. + *

+ * In protobuf, a map is essentially a repeated map entry message with two fields: key and value. + * However, we don't model the map entry message explicitly for performance reasons. Instead, + * we deal with the keys and values directly, and define synthetic Field objects for them here + * for convenience, so that we can reuse the majority of the code generation code. + *

+ * In model implementations we use a custom implementation of the Map interface named PbjMap + * which is an immutable map that exposes a SortedKeys list which allows one to iterate + * the map deterministically which is useful for serializing, computing the hash code, etc. + */ +public record MapField( + /** A synthetic "key" field in a map entry. */ + Field keyField, + /** A synthetic "value" field in a map entry. */ + Field valueField, + // The rest of the fields below simply implement the Field interface: + boolean repeated, + int fieldNumber, + String name, + FieldType type, + String protobufFieldType, + String javaFieldTypeBase, + String methodNameType, + String parseCode, + String javaDefault, + String parserFieldsSetMethodCase, + String comment, + boolean deprecated +) implements Field { + + /** + * Construct a MapField instance out of a MapFieldContext and a lookup helper. + */ + public MapField(Protobuf3Parser.MapFieldContext mapContext, final ContextualLookupHelper lookupHelper) { + this( + new SingleField( + false, + FieldType.of(mapContext.keyType(), lookupHelper), + 1, + "___" + mapContext.mapName().getText() + "__key", + null, + null, + null, + null, + "An internal, private map entry key for " + mapContext.mapName().getText(), + false, + null), + new SingleField( + false, + FieldType.of(mapContext.type_(), lookupHelper), + 2, + "___" + mapContext.mapName().getText() + "__value", + Field.extractMessageTypeName(mapContext.type_()), + Field.extractMessageTypePackage(mapContext.type_(), FileType.MODEL, lookupHelper), + Field.extractMessageTypePackage(mapContext.type_(), FileType.CODEC, lookupHelper), + Field.extractMessageTypePackage(mapContext.type_(), FileType.TEST, lookupHelper), + "An internal, private map entry value for " + mapContext.mapName().getText(), + false, + null), + + false, // maps cannot be repeated + Integer.parseInt(mapContext.fieldNumber().getText()), + mapContext.mapName().getText(), + FieldType.MAP, + "", + "", + "", + null, + "PbjMap.EMPTY", + "", + Common.buildCleanFieldJavaDoc(Integer.parseInt(mapContext.fieldNumber().getText()), mapContext.docComment()), + getDeprecatedOption(mapContext.fieldOptions()) + ); + } + + /** + * Composes the Java generic type of the map field, e.g. "<Integer, String>" for a Map<Integer, String>. + */ + public String javaGenericType() { + return "<" + keyField.type().boxedType + ", " + + (valueField().type() == FieldType.MESSAGE ? ((SingleField)valueField()).messageType() : valueField().type().boxedType) + + ">"; + } + + /** + * {@inheritDoc} + */ + @Override + public String javaFieldType() { + return "Map" + javaGenericType(); + } + + private void composeFieldDef(StringBuilder sb, Field field) { + sb.append(""" + /** + * $doc + */ + """ + .replace("$doc", field.comment().replaceAll("\n","\n * ")) + ); + sb.append(" public static final FieldDefinition %s = new FieldDefinition(\"%s\", FieldType.%s, %s, false, false, %d);\n" + .formatted(Common.camelToUpperSnake(field.name()), field.name(), field.type().fieldType(), field.repeated(), field.fieldNumber())); + } + + /** + * {@inheritDoc} + */ + @Override + public String schemaFieldsDef() { + StringBuilder sb = new StringBuilder(); + composeFieldDef(sb, this); + composeFieldDef(sb, keyField); + composeFieldDef(sb, valueField); + return sb.toString(); + } + + /** + * {@inheritDoc} + */ + @Override + public String schemaGetFieldsDefCase() { + return "case %d -> %s;".formatted(fieldNumber, Common.camelToUpperSnake(name)); + } + + /** + * {@inheritDoc} + */ + @Override + public void addAllNeededImports( + final Set imports, + final boolean modelImports, + final boolean codecImports, + final boolean testImports) { + if (modelImports) { + imports.add("java.util"); + } + if (codecImports) { + imports.add("java.util.stream"); + imports.add("com.hedera.pbj.runtime.test"); + } + } +} diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java index 68e7cd22..28e09f00 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/SingleField.java @@ -31,14 +31,12 @@ public record SingleField(boolean repeated, FieldType type, int fieldNumber, Str public SingleField(Protobuf3Parser.FieldContext fieldContext, final ContextualLookupHelper lookupHelper) { this(fieldContext.REPEATED() != null, FieldType.of(fieldContext.type_(), lookupHelper), - Integer.parseInt(fieldContext.fieldNumber().getText()), fieldContext.fieldName().getText(), - (fieldContext.type_().messageType() == null) ? null : - fieldContext.type_().messageType().messageName().getText(), - (fieldContext.type_().messageType() == null || fieldContext.type_().messageType().messageName().getText() == null) ? null : - lookupHelper.getPackageFieldMessageType(FileType.MODEL, fieldContext), - (fieldContext.type_().messageType() == null || fieldContext.type_().messageType().messageName().getText() == null) ? null : - lookupHelper.getPackageFieldMessageType(FileType.CODEC, fieldContext), (fieldContext.type_().messageType() == null || fieldContext.type_().messageType().messageName().getText() == null) ? null : - lookupHelper.getPackageFieldMessageType(FileType.TEST, fieldContext), + Integer.parseInt(fieldContext.fieldNumber().getText()), + fieldContext.fieldName().getText(), + Field.extractMessageTypeName(fieldContext.type_()), + Field.extractMessageTypePackage(fieldContext.type_(), FileType.MODEL, lookupHelper), + Field.extractMessageTypePackage(fieldContext.type_(), FileType.CODEC, lookupHelper), + Field.extractMessageTypePackage(fieldContext.type_(), FileType.TEST, lookupHelper), Common.buildCleanFieldJavaDoc(Integer.parseInt(fieldContext.fieldNumber().getText()), fieldContext.docComment()), getDeprecatedOption(fieldContext.fieldOptions()), null @@ -320,13 +318,13 @@ public String parserFieldsSetMethodCase() { * @param optionContext protobuf options from parser * @return true if field has deprecated option, otherwise false */ - private static boolean getDeprecatedOption(Protobuf3Parser.FieldOptionsContext optionContext) { + static boolean getDeprecatedOption(Protobuf3Parser.FieldOptionsContext optionContext) { if (optionContext != null) { for (var option : optionContext.fieldOption()) { if ("deprecated".equals(option.optionName().getText())) { return true; } else { - System.err.println("Unhandled Option on enum: "+optionContext.getText()); + System.err.println("Unhandled Option: " + optionContext.getText()); } } } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java index 0762bd61..a51d20fb 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/ModelGenerator.java @@ -15,6 +15,7 @@ import com.hedera.pbj.compiler.impl.Field; import com.hedera.pbj.compiler.impl.Field.FieldType; import com.hedera.pbj.compiler.impl.FileType; +import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; import com.hedera.pbj.compiler.impl.SingleField; import com.hedera.pbj.compiler.impl.grammar.Protobuf3Parser; @@ -101,7 +102,9 @@ public void generate(final MessageDefContext msgDef, } else if (item.oneof() != null) { // process one ofs oneofGetters.addAll(generateCodeForOneOf(lookupHelper, item, javaRecordName, imports, oneofEnums, fields)); } else if (item.mapField() != null) { // process map fields - System.err.println("Encountered a mapField that was not handled in " + javaRecordName); + final MapField field = new MapField(item.mapField(), lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, false, false); } else if (item.field() != null && item.field().fieldName() != null) { generateCodeForField(lookupHelper, item, fields, imports, hasMethods); } else if (item.optionStatement() != null){ @@ -406,6 +409,12 @@ private static String generateConstructor( ); break; } + case MAP: { + sb.append("this.$name = PbjMap.of($name);" + .replace("$name", field.nameCamelFirstLower()) + ); + break; + } default: if (field.repeated()) { sb.append("this.$name = $name == null ? Collections.emptyList() : $name;".replace("$name", field.nameCamelFirstLower())); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/SchemaGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/SchemaGenerator.java index 13d1ddd8..5f9ae52d 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/SchemaGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/SchemaGenerator.java @@ -37,7 +37,9 @@ public void generate(final Protobuf3Parser.MessageDefContext msgDef, final File fields.add(field); field.addAllNeededImports(imports, true, false, false); } else if (item.mapField() != null) { // process map flattenedFields - throw new IllegalStateException("Encountered a mapField that was not handled in parser"); + final MapField field = new MapField(item.mapField(), lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, false, false); } else if (item.field() != null && item.field().fieldName() != null) { final var field = new SingleField(item.field(), lookupHelper); fields.add(field); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java index a2e49bdb..46c6f8e1 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/TestGenerator.java @@ -11,6 +11,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; @@ -46,7 +47,9 @@ public void generate(Protobuf3Parser.MessageDefContext msgDef, File destinationS subField.addAllNeededImports(imports, true, false, true); } } else if (item.mapField() != null) { // process map fields - throw new IllegalStateException("Encountered a mapField that was not handled in "+ modelClassName); + final MapField field = new MapField(item.mapField(), lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, false, true); } else if (item.field() != null && item.field().fieldName() != null) { final var field = new SingleField(item.field(), lookupHelper); fields.add(field); @@ -174,12 +177,12 @@ private static String generateTestData(String modelClassName, Field field, boole generateListArguments(%s)""".formatted(optionsList); } else if (field instanceof final OneOfField oneOf) { final List options = new ArrayList<>(); - for (var subField: oneOf.fields()) { + for (var subField : oneOf.fields()) { if (subField instanceof SingleField) { final String enumValueName = Common.camelToUpperSnake(subField.name()); // special cases to break cyclic dependencies if (!("THRESHOLD_KEY".equals(enumValueName) || "KEY_LIST".equals(enumValueName) - || "THRESHOLD_SIGNATURE".equals(enumValueName)|| "SIGNATURE_LIST".equals(enumValueName))) { + || "THRESHOLD_SIGNATURE".equals(enumValueName) || "SIGNATURE_LIST".equals(enumValueName))) { final String listStr; if (subField.optionalValueType()) { Field.FieldType convertedSubFieldType = getOptionalConvertedFieldType(subField); @@ -187,19 +190,19 @@ private static String generateTestData(String modelClassName, Field field, boole } else { listStr = getOptionsForFieldType(subField.type(), ((SingleField) subField).javaFieldTypeForTest()); } - options.add(listStr + ("\n.stream()\n"+ - """ + options.add(listStr + ("\n.stream()\n" + + """ .map(value -> new %s<>(%sOneOfType.%s, value)) .toList()""".formatted( ((OneOfField) field).className(), modelClassName + "." + field.nameCamelFirstUpper(), enumValueName - )).indent(DEFAULT_INDENT ) + )).indent(DEFAULT_INDENT) ); } } else { - System.err.println("Did not expect a OneOfField in a OneOfField. In "+ - "modelClassName="+modelClassName+" field="+field+" subField="+subField); + System.err.println("Did not expect a OneOfField in a OneOfField. In " + + "modelClassName=" + modelClassName + " field=" + field + " subField=" + subField); } } return """ @@ -207,10 +210,36 @@ private static String generateTestData(String modelClassName, Field field, boole List.of(new %s<>(%sOneOfType.UNSET, null)), %s ).flatMap(List::stream).toList()""".formatted( - ((OneOfField)field).className(), - modelClassName+"."+field.nameCamelFirstUpper(), - String.join(",\n", options).indent(DEFAULT_INDENT) + ((OneOfField) field).className(), + modelClassName + "." + field.nameCamelFirstUpper(), + String.join(",\n", options).indent(DEFAULT_INDENT) ).indent(DEFAULT_INDENT * 2); + } else if (field instanceof final MapField mapField) { + // e.g. INTEGER_TESTS_LIST + final String keyOptions = getOptionsForFieldType(mapField.keyField().type(), mapField.keyField().javaFieldType()); + // e.g. STRING_TESTS_LIST, or, say, CustomMessageTest.ARGUMENTS + final String valueOptions = getOptionsForFieldType(mapField.valueField().type(), mapField.valueField().javaFieldType()); + + // A cartesian product is nice to use, but it doesn't seem reasonable from the performance perspective. + // Instead, we want to test three cases: + // 1. Empty map + // 2. Map with a single entry + // 3. Map with multiple (e.g. two) entries + // Note that keys and value options lists may be pretty small. E.g. Boolean would only have 2 elements. So we use mod. + // Also note that we assume there's at least one element in each list. + return """ + List.of( + Map.$javaGenericTypeof(), + Map.$javaGenericTypeof($keyOptions.get(0), $valueOptions.get(0)), + Map.$javaGenericTypeof( + $keyOptions.get(1 % $keyOptions.size()), $valueOptions.get(1 % $valueOptions.size()), + $keyOptions.get(2 % $keyOptions.size()), $valueOptions.get(2 % $valueOptions.size()) + ) + )""" + .replace("$javaGenericType", mapField.javaGenericType()) + .replace("$keyOptions", keyOptions) + .replace("$valueOptions", valueOptions) + ; } else { return getOptionsForFieldType(field.type(), ((SingleField)field).javaFieldTypeForTest()); } @@ -245,6 +274,7 @@ private static String getOptionsForFieldType(Field.FieldType fieldType, String j case ENUM -> "Arrays.asList(" + javaFieldType + ".values())"; case ONE_OF -> throw new RuntimeException("Should never happen, should have been caught in generateTestData()"); case MESSAGE -> javaFieldType + FileAndPackageNamesConfig.TEST_JAVA_FILE_SUFFIX + ".ARGUMENTS"; + case MAP -> throw new RuntimeException("Should never happen, should have been caught in generateTestData()"); }; } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java index dbbc935f..74fbda6f 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecGenerator.java @@ -42,7 +42,9 @@ public void generate(Protobuf3Parser.MessageDefContext msgDef, final File destin fields.add(field); field.addAllNeededImports(imports, true, true, false); } else if (item.mapField() != null) { // process map fields - throw new IllegalStateException("Encountered a mapField that was not handled in "+ codecClassName); + final MapField field = new MapField(item.mapField(), lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, true, false); } else if (item.field() != null && item.field().fieldName() != null) { final var field = new SingleField(item.field(), lookupHelper); fields.add(field); diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java index 0285b05e..dfcde8c1 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecParseMethodGenerator.java @@ -2,6 +2,7 @@ import com.hedera.pbj.compiler.impl.Common; import com.hedera.pbj.compiler.impl.Field; +import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; import java.util.List; @@ -67,7 +68,7 @@ static String generateParseObjectMethod(final String modelClassName, final List< for (JSONParser.PairContext kvPair : root.pair()) { switch (kvPair.STRING().getText()) { $caseStatements - default -> { + default: { if (strictMode) { // Since we are parsing is strict mode, this is an exceptional condition. throw new UnknownFieldException(kvPair.STRING().getText()); @@ -103,34 +104,36 @@ private static String generateCaseStatements(final List fields) { if (field instanceof final OneOfField oneOfField) { for(final Field subField: oneOfField.fields()) { sb.append("case \"" + toJsonFieldName(subField.name()) +"\" /* [" + subField.fieldNumber() + "] */ " + - "-> temp_" + oneOfField.name() + " = new %s<>(\n".formatted(oneOfField.className()) + + ": temp_" + oneOfField.name() + " = new %s<>(\n".formatted(oneOfField.className()) + oneOfField.getEnumClassRef().indent(DEFAULT_INDENT) +"."+Common.camelToUpperSnake(subField.name())+ ", \n".indent(DEFAULT_INDENT)); - generateFieldCaseStatement(sb,subField); - sb.append(");\n"); + generateFieldCaseStatement(sb, subField, "kvPair.value()"); + sb.append("); break;\n"); } } else { sb.append("case \"" + toJsonFieldName(field.name()) +"\" /* [" + field.fieldNumber() + "] */ " + - "-> temp_" + field.name()+" = "); - generateFieldCaseStatement(sb, field); - sb.append(";\n"); + ": temp_" + field.name()+" = "); + generateFieldCaseStatement(sb, field, "kvPair.value()"); + sb.append("; break;\n"); } } - return sb.toString().indent(DEFAULT_INDENT * 3); + return sb.toString(); } /** * Generate switch case statement for a field. * * @param field field to generate case statement for - * @param sb StringBuilder to append code to + * @param origSB StringBuilder to append code to + * @param valueGetter normally a "kvPair.value()", but may be different e.g. for maps parsing */ - private static void generateFieldCaseStatement(final StringBuilder sb, final Field field) { + private static void generateFieldCaseStatement(final StringBuilder origSB, final Field field, final String valueGetter) { + final StringBuilder sb = new StringBuilder(); if (field.repeated()) { if (field.type() == Field.FieldType.MESSAGE) { - sb.append("parseObjArray(kvPair.value().arr(), "+field.messageType()+".JSON, maxDepth - 1)"); + sb.append("parseObjArray($valueGetter.arr(), "+field.messageType()+".JSON, maxDepth - 1)"); } else { - sb.append("kvPair.value().arr().value().stream().map(v -> "); + sb.append("$valueGetter.arr().value().stream().map(v -> "); switch (field.type()) { case ENUM -> sb.append(field.messageType() + ".fromString(v.STRING().getText())"); case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> sb.append("parseInteger(v)"); @@ -145,29 +148,48 @@ private static void generateFieldCaseStatement(final StringBuilder sb, final Fie sb.append(").toList()"); } } else if (field.optionalValueType()) { - switch(field.messageType()) { - case "Int32Value", "UInt32Value" -> sb.append("parseInteger(kvPair.value())"); - case "Int64Value", "UInt64Value" -> sb.append("parseLong(kvPair.value())"); - case "FloatValue" -> sb.append("parseFloat(kvPair.value())"); - case "DoubleValue" -> sb.append("parseDouble(kvPair.value())"); - case "StringValue" -> sb.append("unescape(kvPair.value().STRING().getText())"); - case "BoolValue" -> sb.append("parseBoolean(kvPair.value())"); - case "BytesValue" -> sb.append("Bytes.fromBase64(kvPair.value().STRING().getText())"); - default -> throw new RuntimeException("Unknown message type ["+field.messageType()+"]"); + switch (field.messageType()) { + case "Int32Value", "UInt32Value" -> sb.append("parseInteger($valueGetter)"); + case "Int64Value", "UInt64Value" -> sb.append("parseLong($valueGetter)"); + case "FloatValue" -> sb.append("parseFloat($valueGetter)"); + case "DoubleValue" -> sb.append("parseDouble($valueGetter)"); + case "StringValue" -> sb.append("unescape($valueGetter.STRING().getText())"); + case "BoolValue" -> sb.append("parseBoolean($valueGetter)"); + case "BytesValue" -> sb.append("Bytes.fromBase64($valueGetter.STRING().getText())"); + default -> throw new RuntimeException("Unknown message type [" + field.messageType() + "]"); } + } else if (field.type() == Field.FieldType.MAP) { + final MapField mapField = (MapField) field; + + final StringBuilder keySB = new StringBuilder(); + final StringBuilder valueSB = new StringBuilder(); + + generateFieldCaseStatement(keySB, mapField.keyField(), "mapKV"); + generateFieldCaseStatement(valueSB, mapField.valueField(), "mapKV.value()"); + + sb.append(""" + $valueGetter.getChild(JSONParser.ObjContext.class, 0).pair().stream() + .collect(Collectors.toMap( + mapKV -> $mapEntryKey, + new UncheckedThrowingFunction<>(mapKV -> $mapEntryValue) + ))""" + .replace("$mapEntryKey", keySB.toString()) + .replace("$mapEntryValue", valueSB.toString()) + ); } else { switch (field.type()) { - case MESSAGE -> sb.append(field.javaFieldType() + ".JSON.parse(kvPair.value().getChild(JSONParser.ObjContext.class, 0), false, maxDepth - 1)"); - case ENUM -> sb.append(field.javaFieldType() + ".fromString(kvPair.value().STRING().getText())"); - case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> sb.append("parseInteger(kvPair.value())"); - case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> sb.append("parseLong(kvPair.value())"); - case FLOAT -> sb.append("parseFloat(kvPair.value())"); - case DOUBLE -> sb.append("parseDouble(kvPair.value())"); - case STRING -> sb.append("unescape(kvPair.value().STRING().getText())"); - case BOOL -> sb.append("parseBoolean(kvPair.value())"); - case BYTES -> sb.append("Bytes.fromBase64(kvPair.value().STRING().getText())"); + case MESSAGE -> sb.append(field.javaFieldType() + ".JSON.parse($valueGetter.getChild(JSONParser.ObjContext.class, 0), false, maxDepth - 1)"); + case ENUM -> sb.append(field.javaFieldType() + ".fromString($valueGetter.STRING().getText())"); + case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> sb.append("parseInteger($valueGetter)"); + case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> sb.append("parseLong($valueGetter)"); + case FLOAT -> sb.append("parseFloat($valueGetter)"); + case DOUBLE -> sb.append("parseDouble($valueGetter)"); + case STRING -> sb.append("unescape($valueGetter.STRING().getText())"); + case BOOL -> sb.append("parseBoolean($valueGetter)"); + case BYTES -> sb.append("Bytes.fromBase64($valueGetter.STRING().getText())"); default -> throw new RuntimeException("Unknown field type ["+field.type()+"]"); } } + origSB.append(sb.toString().replace("$valueGetter", valueGetter)); } } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java index ad1f8265..d8d6bf31 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/json/JsonCodecWriteMethodGenerator.java @@ -2,6 +2,7 @@ import com.hedera.pbj.compiler.impl.Common; import com.hedera.pbj.compiler.impl.Field; +import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; import com.hedera.pbj.compiler.impl.SingleField; import edu.umd.cs.findbugs.annotations.NonNull; @@ -75,7 +76,7 @@ public String toJSON(@NonNull $modelClass data, String indent, boolean inline) { private static String generateFieldWriteLines(final Field field, final String modelClassName, String getValueCode) { final String fieldDef = Common.camelToUpperSnake(field.name()); final String fieldName = '\"' + toJsonFieldName(field.name()) + '\"'; - final String basicFieldCode = generateBasicFieldLines(field, getValueCode, fieldDef, fieldName); + final String basicFieldCode = generateBasicFieldLines(field, getValueCode, fieldDef, fieldName, "childIndent"); String prefix = "// ["+field.fieldNumber()+"] - "+field.name() + "\n"; if (field.parent() != null) { @@ -88,10 +89,13 @@ private static String generateFieldWriteLines(final Field field, final String mo } else { if (field.repeated()) { return prefix + "if (!data." + field.nameCamelFirstLower() + "().isEmpty()) fieldLines.add(" + basicFieldCode + ");"; - } else if (field.type() == Field.FieldType.BYTES){ + } else if (field.type() == Field.FieldType.BYTES) { return prefix + "if (data." + field.nameCamelFirstLower() + "() != " + field.javaDefault() + " && data." + field.nameCamelFirstLower() + "() != null" + " && data." + field.nameCamelFirstLower() + "().length() > 0) fieldLines.add(" + basicFieldCode + ");"; + } else if (field.type() == Field.FieldType.MAP) { + return prefix + "if (data." + field.nameCamelFirstLower() + "() != " + field.javaDefault() + + " && !data." + field.nameCamelFirstLower() + "().isEmpty()) fieldLines.add(" + basicFieldCode + ");"; } else { return prefix + "if (data." + field.nameCamelFirstLower() + "() != " + field.javaDefault() + ") fieldLines.add(" + basicFieldCode + ");"; } @@ -99,7 +103,7 @@ private static String generateFieldWriteLines(final Field field, final String mo } @NonNull - private static String generateBasicFieldLines(Field field, String getValueCode, String fieldDef, String fieldName) { + private static String generateBasicFieldLines(Field field, String getValueCode, String fieldDef, String fieldName, String childIndent) { if (field.optionalValueType()) { return switch (field.messageType()) { case "StringValue", "BoolValue", "Int32Value", @@ -123,13 +127,29 @@ private static String generateBasicFieldLines(Field field, String getValueCode, .replace("$fieldDef", fieldDef) .replace("$valueCode", getValueCode); }; + } else if (field.type() == Field.FieldType.MAP) { + final MapField mapField = (MapField) field; + final String vComposerMethod = generateBasicFieldLines( + mapField.valueField(), + "v", + Common.camelToUpperSnake(mapField.valueField().name()), + "n", + "indent" + ); + return "field(%s, %s, $kEncoder, $vComposer)" + .formatted(fieldName, getValueCode) + // Maps in protobuf can only have simple scalar and not floating keys, so toString should do a good job. + // Also see https://protobuf.dev/programming-guides/proto3/#json + .replace("$kEncoder", "k -> escape(k.toString())") + .replace("$vComposer", "(n, v) -> " + vComposerMethod); } else { return switch (field.type()) { case ENUM -> "field($fieldName, $valueCode.protoName())" .replace("$fieldName", fieldName) .replace("$fieldDef", fieldDef) .replace("$valueCode", getValueCode); - case MESSAGE -> "field(childIndent, $fieldName, $codec, $valueCode)" + case MESSAGE -> "field($childIndent, $fieldName, $codec, $valueCode)" + .replace("$childIndent", childIndent) .replace("$fieldName", fieldName) .replace("$fieldDef", fieldDef) .replace("$valueCode", getValueCode) diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java index 6bf49d64..5e0fb10d 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java @@ -39,7 +39,9 @@ public void generate(Protobuf3Parser.MessageDefContext msgDef, final File destin fields.add(field); field.addAllNeededImports(imports, true, true, false); } else if (item.mapField() != null) { // process map fields - throw new IllegalStateException("Encountered a mapField that was not handled in "+ codecClassName); + final MapField field = new MapField(item.mapField(), lookupHelper); + fields.add(field); + field.addAllNeededImports(imports, true, true, false); } else if (item.field() != null && item.field().fieldName() != null) { final var field = new SingleField(item.field(), lookupHelper); fields.add(field); @@ -69,7 +71,7 @@ public void generate(Protobuf3Parser.MessageDefContext msgDef, final File destin import static $schemaClass.*; import static com.hedera.pbj.runtime.ProtoWriterTools.*; import static com.hedera.pbj.runtime.ProtoParserTools.*; - import static com.hedera.pbj.runtime.ProtoConstants.TAG_WIRE_TYPE_MASK; + import static com.hedera.pbj.runtime.ProtoConstants.*; /** * Protobuf Codec for $modelClass model object. Generated based on protobuf schema. diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java index 8b2cb588..622aaaa6 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java @@ -4,11 +4,13 @@ import com.hedera.pbj.compiler.impl.Common; import com.hedera.pbj.compiler.impl.Field; +import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; import com.hedera.pbj.compiler.impl.SingleField; import java.util.Comparator; import java.util.List; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -19,11 +21,11 @@ class CodecMeasureRecordMethodGenerator { static String generateMeasureMethod(final String modelClassName, final List fields) { - final String fieldSizeOfLines = fields.stream() - .flatMap(field -> field.type() == Field.FieldType.ONE_OF ? ((OneOfField)field).fields().stream() : Stream.of(field)) - .sorted(Comparator.comparingInt(Field::fieldNumber)) - .map(field -> generateFieldSizeOfLines(field, modelClassName, "data.%s()".formatted(field.nameCamelFirstLower()))) - .collect(Collectors.joining("\n")).indent(DEFAULT_INDENT); + final String fieldSizeOfLines = buildFieldSizeOfLines( + modelClassName, + fields, + field -> "data.%s()".formatted(field.nameCamelFirstLower()), + true); return """ /** * Compute number of bytes that would be written when calling {@code write()} method. @@ -42,15 +44,28 @@ public int measureRecord($modelClass data) { .indent(DEFAULT_INDENT); } + static String buildFieldSizeOfLines( + final String modelClassName, + final List fields, + final Function getValueBuilder, + boolean skipDefault) { + return fields.stream() + .flatMap(field -> field.type() == Field.FieldType.ONE_OF ? ((OneOfField)field).fields().stream() : Stream.of(field)) + .sorted(Comparator.comparingInt(Field::fieldNumber)) + .map(field -> generateFieldSizeOfLines(field, modelClassName, getValueBuilder.apply(field), skipDefault)) + .collect(Collectors.joining("\n")).indent(DEFAULT_INDENT); + } + /** * Generate lines of code for measure method, that measure the size of each field and add to "size" variable. * * @param field The field to generate size of line * @param modelClassName The model class name for model class for message type we are generating writer for * @param getValueCode java code to get the value of field + * @param skipDefault true if default value of the field should result in size zero * @return java code for adding fields size to "size" variable */ - private static String generateFieldSizeOfLines(final Field field, final String modelClassName, String getValueCode) { + private static String generateFieldSizeOfLines(final Field field, final String modelClassName, String getValueCode, boolean skipDefault) { final String fieldDef = Common.camelToUpperSnake(field.name()); String prefix = "// ["+field.fieldNumber()+"] - "+field.name(); prefix += "\n"; @@ -84,30 +99,63 @@ private static String generateFieldSizeOfLines(final Field field, final String m default -> throw new UnsupportedOperationException("Unhandled optional message type:"+field.messageType()); }; } else if (field.repeated()) { - return prefix + switch(field.type()) { + return prefix + switch (field.type()) { case ENUM -> "size += sizeOfEnumList(%s, %s);" .formatted(fieldDef, getValueCode); case MESSAGE -> "size += sizeOfMessageList($fieldDef, $valueCode, $codec::measureRecord);" .replace("$fieldDef", fieldDef) .replace("$valueCode", getValueCode) - .replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." + - Common.capitalizeFirstLetter(field.messageType())+ ".PROTOBUF"); + .replace("$codec", ((SingleField) field).messageTypeModelPackage() + "." + + Common.capitalizeFirstLetter(field.messageType()) + ".PROTOBUF"); default -> "size += sizeOf%sList(%s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); }; + } else if (field.type() == Field.FieldType.MAP) { + final MapField mapField = (MapField) field; + final List mapEntryFields = List.of(mapField.keyField(), mapField.valueField()); + final Function getValueBuilder = mapEntryField -> + mapEntryField == mapField.keyField() ? "k" : (mapEntryField == mapField.valueField() ? "v" : null); + final String fieldSizeOfLines = CodecMeasureRecordMethodGenerator.buildFieldSizeOfLines( + field.name(), + mapEntryFields, + getValueBuilder, + false); + return prefix + """ + if (!$map.isEmpty()) { + final Pbj$javaFieldType pbjMap = (Pbj$javaFieldType) $map; + final int mapSize = pbjMap.size(); + for (int i = 0; i < mapSize; i++) { + size += sizeOfTag($fieldDef, WIRE_TYPE_DELIMITED); + final int sizePre = size; + $K k = pbjMap.getSortedKeys().get(i); + $V v = pbjMap.get(k); + $fieldSizeOfLines + size += sizeOfVarInt32(size - sizePre); + } + } + """ + .replace("$fieldDef", fieldDef) + .replace("$map", getValueCode) + .replace("$javaFieldType", mapField.javaFieldType()) + .replace("$K", mapField.keyField().type().boxedType) + .replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? ((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType) + .replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT)) + ; } else { return prefix + switch(field.type()) { case ENUM -> "size += sizeOfEnum(%s, %s);" .formatted(fieldDef, getValueCode); - case STRING -> "size += sizeOfString(%s, %s);" - .formatted(fieldDef,getValueCode); + case STRING -> "size += sizeOfString(%s, %s, %s);" + .formatted(fieldDef, getValueCode, skipDefault); case MESSAGE -> "size += sizeOfMessage($fieldDef, $valueCode, $codec::measureRecord);" .replace("$fieldDef", fieldDef) .replace("$valueCode", getValueCode) .replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." + Common.capitalizeFirstLetter(field.messageType())+ ".PROTOBUF"); - case BOOL -> "size += sizeOfBoolean(%s, %s);" - .formatted(fieldDef,getValueCode); + case BOOL -> "size += sizeOfBoolean(%s, %s, %s);" + .formatted(fieldDef, getValueCode, skipDefault); + case INT32, UINT32, SINT32, FIXED32, SFIXED32, INT64, SINT64, UINT64, FIXED64, SFIXED64, BYTES -> "size += sizeOf%s(%s, %s, %s);" + .formatted(writeMethodName, fieldDef, getValueCode, skipDefault); default -> "size += sizeOf%s(%s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); }; diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java index b544d984..284d8c00 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecParseMethodGenerator.java @@ -4,8 +4,10 @@ import com.hedera.pbj.compiler.impl.Common; import com.hedera.pbj.compiler.impl.Field; +import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; import com.hedera.pbj.compiler.impl.PbjCompilerException; +import edu.umd.cs.findbugs.annotations.NonNull; import java.util.List; import java.util.stream.Collectors; @@ -66,6 +68,28 @@ static String generateParseMethod(final String modelClassName, final List // -- TEMP STATE FIELDS -------------------------------------- $fieldDefs + $parseLoop + return new $modelClassName($fieldsList); + } catch (final Exception anyException) { + if (anyException instanceof ParseException parseException) { + throw parseException; + } + throw new ParseException(anyException); + } + } + """ + .replace("$modelClassName",modelClassName) + .replace("$fieldDefs",fields.stream().map(field -> " %s temp_%s = %s;".formatted(field.javaFieldType(), + field.name(), field.javaDefault())).collect(Collectors.joining("\n"))) + .replace("$fieldsList",fields.stream().map(field -> "temp_"+field.name()).collect(Collectors.joining(", "))) + .replace("$parseLoop", generateParseLoop(generateCaseStatements(fields), "")) + .replace("$skipMaxSize", String.valueOf(Field.DEFAULT_MAX_SIZE)) + .indent(DEFAULT_INDENT); + } + + // prefix is pre-pended to variable names to support a nested parsing loop. + static String generateParseLoop(final String caseStatements, @NonNull final String prefix) { + return """ // -- PARSE LOOP --------------------------------------------- // Continue to parse bytes out of the input stream until we get to the end. while (input.hasRemaining()) { @@ -74,33 +98,33 @@ static String generateParseMethod(final String modelClassName, final List // So we catch this exception here and **only** here, because an EOFException // anywhere else suggests that we're processing malformed data and so // we must re-throw the exception then. - final int tag; + final int $prefixtag; try { // Read the "tag" byte which gives us the field number for the next field to read // and the wire type (way it is encoded on the wire). - tag = input.readVarInt(false); + $prefixtag = input.readVarInt(false); } catch (EOFException e) { // There's no more fields. Stop the parsing loop. break; } // The field is the top 5 bits of the byte. Read this off - final int field = tag >>> TAG_FIELD_OFFSET; + final int $prefixfield = $prefixtag >>> TAG_FIELD_OFFSET; // Ask the Schema to inform us what field this represents. - final var f = getField(field); + final var $prefixf = getField($prefixfield); // Given the wire type and the field type, parse the field - switch (tag) { + switch ($prefixtag) { $caseStatements default -> { // The wire type is the bottom 3 bits of the byte. Read that off - final int wireType = tag & TAG_WIRE_TYPE_MASK; + final int wireType = $prefixtag & TAG_WIRE_TYPE_MASK; // handle error cases here, so we do not do if statements in normal loop // Validate the field number is valid (must be > 0) - if (field == 0) { + if ($prefixfield == 0) { throw new IOException("Bad protobuf encoding. We read a field value of " - + field); + + $prefixfield); } // Validate the wire type is valid (must be >=0 && <= 5). // Otherwise we cannot parse this. @@ -109,38 +133,27 @@ static String generateParseMethod(final String modelClassName, final List throw new IOException("Cannot understand wire_type of " + wireType); } // It may be that the parser subclass doesn't know about this field - if (f == null) { + if ($prefixf == null) { if (strictMode) { // Since we are parsing is strict mode, this is an exceptional condition. - throw new UnknownFieldException(field); + throw new UnknownFieldException($prefixfield); } else { // We just need to read off the bytes for this field to skip it // and move on to the next one. skipField(input, ProtoConstants.get(wireType), $skipMaxSize); } } else { - throw new IOException("Bad tag [" + tag + "], field [" + field + throw new IOException("Bad tag [" + $prefixtag + "], field [" + $prefixfield + "] wireType [" + wireType + "]"); } } } } - return new $modelClassName($fieldsList); - } catch (final Exception anyException) { - if (anyException instanceof ParseException parseException) { - throw parseException; - } - throw new ParseException(anyException); - } - } """ - .replace("$modelClassName",modelClassName) - .replace("$fieldDefs",fields.stream().map(field -> " %s temp_%s = %s;".formatted(field.javaFieldType(), - field.name(), field.javaDefault())).collect(Collectors.joining("\n"))) - .replace("$fieldsList",fields.stream().map(field -> "temp_"+field.name()).collect(Collectors.joining(", "))) - .replace("$caseStatements",generateCaseStatements(fields)) - .replace("$skipMaxSize", String.valueOf(Field.DEFAULT_MAX_SIZE)) - .indent(DEFAULT_INDENT); + .replace("$caseStatements",caseStatements) + .replace("$prefix",prefix) + .replace("$skipMaxSize", String.valueOf(Field.DEFAULT_MAX_SIZE)) + .indent(DEFAULT_INDENT); } /** @@ -303,6 +316,46 @@ private static void generateFieldCaseStatement(final StringBuilder sb, final Fie .replace("$maxSize", String.valueOf(field.maxSize())) .indent(DEFAULT_INDENT) ); + } else if (field.type() == Field.FieldType.MAP) { + // This is almost like reading a message above because that's how Protobuf encodes map entries. + // However(!), we read the key and value fields explicitly to avoid creating temporary entry objects. + final MapField mapField = (MapField) field; + final List mapEntryFields = List.of(mapField.keyField(), mapField.valueField()); + sb.append(""" + final var __map_messageLength = input.readVarInt(false); + + $fieldDefs + if (__map_messageLength != 0) { + if (__map_messageLength > $maxSize) { + throw new ParseException("$fieldName size " + __map_messageLength + " is greater than max " + $maxSize); + } + final var __map_limitBefore = input.limit(); + // Make sure that we have enough bytes in the message + // to read the subObject. + // If the buffer is truncated on the boundary of a subObject, + // we will not throw. + final var __map_startPos = input.position(); + try { + if ((__map_startPos + __map_messageLength) > __map_limitBefore) { + throw new BufferUnderflowException(); + } + input.limit(__map_startPos + __map_messageLength); + $mapParseLoop + // Make sure we read the full number of bytes. for the types + if ((__map_startPos + __map_messageLength) != input.position()) { + throw new BufferOverflowException(); + } + } finally { + input.limit(__map_limitBefore); + } + } + """ + .replace("$fieldName", field.name()) + .replace("$fieldDefs",mapEntryFields.stream().map(mapEntryField -> "%s temp_%s = %s;".formatted(mapEntryField.javaFieldType(), + mapEntryField.name(), mapEntryField.javaDefault())).collect(Collectors.joining("\n"))) + .replace("$mapParseLoop", generateParseLoop(generateCaseStatements(mapEntryFields), "map_entry_").indent(-DEFAULT_INDENT)) + .replace("$maxSize", String.valueOf(field.maxSize())) + ); } else { sb.append(("final var value = " + readMethod(field) + ";\n").indent(DEFAULT_INDENT)); } @@ -316,16 +369,28 @@ private static void generateFieldCaseStatement(final StringBuilder sb, final Fie oneOfField.getEnumClassRef() + '.' + Common.camelToUpperSnake(field.name()) + ", value);\n"); } else if (field.repeated()) { sb.append("if (temp_" + field.name() + ".size() >= " + field.maxSize() + ") {\n"); - sb.append(" throw new ParseException(\"" + field.name() + " size \" + temp_" + field.name() + ".size() + \" is greater than max \" + " + field.maxSize() + ");\n"); - sb.append("}\n"); - sb.append("temp_" + field.name() + " = addToList(temp_" + field.name() + ",value);\n"); + sb.append(" throw new ParseException(\"" + field.name() + " size \" + temp_" + field.name() + ".size() + \" is greater than max \" + " + field.maxSize() + ");\n"); + sb.append(" }\n"); + sb.append(" temp_" + field.name() + " = addToList(temp_" + field.name() + ",value);\n"); + } else if (field.type() == Field.FieldType.MAP) { + final MapField mapField = (MapField) field; + + sb.append("if (__map_messageLength != 0) {\n"); + sb.append(" if (temp_" + field.name() + ".size() >= " + field.maxSize() + ") {\n"); + sb.append(" throw new ParseException(\"" + field.name() + " size \" + temp_" + field.name() + ".size() + \" is greater than max \" + " + field.maxSize() + ");\n"); + sb.append(" }\n"); + sb.append(" temp_" + field.name() + " = addToMap(temp_" + field.name() + ", temp_$key, temp_$value);\n" + .replace("$key", mapField.keyField().name()) + .replace("$value", mapField.valueField().name()) + ); + sb.append(" }\n"); } else { sb.append("temp_" + field.name() + " = value;\n"); } sb.append("}\n"); } - private static String readMethod(Field field) { + static String readMethod(Field field) { if (field.optionalValueType()) { return switch (field.messageType()) { case "StringValue" -> "readString(input, " + field.maxSize() + ")"; @@ -359,6 +424,7 @@ private static String readMethod(Field field) { case BYTES -> "readBytes(input, " + field.maxSize() + ")"; case MESSAGE -> field.parseCode(); case ONE_OF -> throw new PbjCompilerException("Should never happen, oneOf handled elsewhere"); + case MAP -> throw new PbjCompilerException("Should never happen, map handled elsewhere"); }; } } diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java index 63ac89df..415e0a94 100644 --- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java +++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java @@ -4,11 +4,13 @@ import com.hedera.pbj.compiler.impl.Common; import com.hedera.pbj.compiler.impl.Field; +import com.hedera.pbj.compiler.impl.MapField; import com.hedera.pbj.compiler.impl.OneOfField; import com.hedera.pbj.compiler.impl.SingleField; import java.util.Comparator; import java.util.List; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -18,12 +20,11 @@ final class CodecWriteMethodGenerator { static String generateWriteMethod(final String modelClassName, final List fields) { - final String fieldWriteLines = fields.stream() - .flatMap(field -> field.type() == Field.FieldType.ONE_OF ? ((OneOfField)field).fields().stream() : Stream.of(field)) - .sorted(Comparator.comparingInt(Field::fieldNumber)) - .map(field -> generateFieldWriteLines(field, modelClassName, "data.%s()".formatted(field.nameCamelFirstLower()))) - .collect(Collectors.joining("\n")) - .indent(DEFAULT_INDENT); + final String fieldWriteLines = buildFieldWriteLines( + modelClassName, + fields, + field -> "data.%s()".formatted(field.nameCamelFirstLower()), + true); return """ /** @@ -42,6 +43,18 @@ public void write(@NonNull $modelClass data, @NonNull final WritableSequentialDa .indent(DEFAULT_INDENT); } + private static String buildFieldWriteLines( + final String modelClassName, + final List fields, + final Function getValueBuilder, + final boolean skipDefault) { + return fields.stream() + .flatMap(field -> field.type() == Field.FieldType.ONE_OF ? ((OneOfField)field).fields().stream() : Stream.of(field)) + .sorted(Comparator.comparingInt(Field::fieldNumber)) + .map(field -> generateFieldWriteLines(field, modelClassName, getValueBuilder.apply(field), skipDefault)) + .collect(Collectors.joining("\n")) + .indent(DEFAULT_INDENT); + } /** * Generate lines of code for writing field @@ -49,9 +62,10 @@ public void write(@NonNull $modelClass data, @NonNull final WritableSequentialDa * @param field The field to generate writing line of code for * @param modelClassName The model class name for model class for message type we are generating writer for * @param getValueCode java code to get the value of field + * @param skipDefault skip writing the field if it has default value (for non-oneOf only) * @return java code to write field to output */ - private static String generateFieldWriteLines(final Field field, final String modelClassName, String getValueCode) { + private static String generateFieldWriteLines(final Field field, final String modelClassName, String getValueCode, boolean skipDefault) { final String fieldDef = Common.camelToUpperSnake(field.name()); String prefix = "// ["+field.fieldNumber()+"] - "+field.name(); prefix += "\n"; @@ -96,19 +110,70 @@ private static String generateFieldWriteLines(final Field field, final String mo default -> "write%sList(out, %s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); }; + } else if (field.type() == Field.FieldType.MAP) { + // https://protobuf.dev/programming-guides/proto3/#maps + // On the wire, a map is equivalent to: + // message MapFieldEntry { + // key_type key = 1; + // value_type value = 2; + // } + // repeated MapFieldEntry map_field = N; + // NOTE: we serialize the map in the natural order of keys by design, + // so that the binary representation of the map is deterministic. + // NOTE: protoc serializes default values (e.g. "") in maps, so we should too. + final MapField mapField = (MapField) field; + final List mapEntryFields = List.of(mapField.keyField(), mapField.valueField()); + final Function getValueBuilder = mapEntryField -> + mapEntryField == mapField.keyField() ? "k" : (mapEntryField == mapField.valueField() ? "v" : null); + final String fieldWriteLines = buildFieldWriteLines( + field.name(), + mapEntryFields, + getValueBuilder, + false); + final String fieldSizeOfLines = CodecMeasureRecordMethodGenerator.buildFieldSizeOfLines( + field.name(), + mapEntryFields, + getValueBuilder, + false); + return prefix + """ + if (!$map.isEmpty()) { + final Pbj$javaFieldType pbjMap = (Pbj$javaFieldType) $map; + final int mapSize = pbjMap.size(); + for (int i = 0; i < mapSize; i++) { + writeTag(out, $fieldDef, WIRE_TYPE_DELIMITED); + $K k = pbjMap.getSortedKeys().get(i); + $V v = pbjMap.get(k); + int size = 0; + $fieldSizeOfLines + out.writeVarInt(size, false); + $fieldWriteLines + } + } + """ + .replace("$fieldDef", fieldDef) + .replace("$map", getValueCode) + .replace("$javaFieldType", mapField.javaFieldType()) + .replace("$K", mapField.keyField().type().boxedType) + .replace("$V", mapField.valueField().type() == Field.FieldType.MESSAGE ? ((SingleField)mapField.valueField()).messageType() : mapField.valueField().type().boxedType) + .replace("$fieldWriteLines", fieldWriteLines.indent(DEFAULT_INDENT)) + .replace("$fieldSizeOfLines", fieldSizeOfLines.indent(DEFAULT_INDENT)) + + ; } else { return prefix + switch(field.type()) { case ENUM -> "writeEnum(out, %s, %s);" .formatted(fieldDef, getValueCode); - case STRING -> "writeString(out, %s, %s);" - .formatted(fieldDef,getValueCode); + case STRING -> "writeString(out, %s, %s, %s);" + .formatted(fieldDef, getValueCode, skipDefault); case MESSAGE -> "writeMessage(out, $fieldDef, $valueCode, $codec::write, $codec::measureRecord);" .replace("$fieldDef", fieldDef) .replace("$valueCode", getValueCode) .replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." + Common.capitalizeFirstLetter(field.messageType())+ ".PROTOBUF"); - case BOOL -> "writeBoolean(out, %s, %s);" - .formatted(fieldDef,getValueCode); + case BOOL -> "writeBoolean(out, %s, %s, %s);" + .formatted(fieldDef, getValueCode, skipDefault); + case INT32, UINT32, SINT32, FIXED32, SFIXED32, INT64, SINT64, UINT64, FIXED64, SFIXED64, BYTES -> "write%s(out, %s, %s, %s);" + .formatted(writeMethodName, fieldDef, getValueCode, skipDefault); default -> "write%s(out, %s, %s);" .formatted(writeMethodName, fieldDef, getValueCode); }; diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/FieldType.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/FieldType.java index 28de6e04..1f596a9f 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/FieldType.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/FieldType.java @@ -37,7 +37,9 @@ public enum FieldType { /** Protobuf enum type */ ENUM, /** Protobuf sub-message type */ - MESSAGE; + MESSAGE, + /** Protobuf map type */ + MAP; /** * Optional values have an inner field, with a standard definition for every FieldType. We create singleton diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/JsonTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/JsonTools.java index d4a11fe0..6a242874 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/JsonTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/JsonTools.java @@ -16,6 +16,9 @@ import java.nio.CharBuffer; import java.util.Base64; import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -185,6 +188,56 @@ public static boolean parseBoolean(JSONParser.ValueContext valueContext) { return Boolean.parseBoolean(valueContext.getText()); } + /** + * Parse an integer from a JSONParser.PairContext + * + * @param pairContext the JSONParser.PairContext to parse + * @return the parsed integer + */ + public static int parseInteger(JSONParser.PairContext pairContext) { + return Integer.parseInt(pairContext.STRING().getText()); + } + + /** + * Parse a long from a JSONParser.PairContext + * + * @param pairContext the JSONParser.PairContext to parse + * @return the parsed long + */ + public static long parseLong(JSONParser.PairContext pairContext) { + return Long.parseLong(pairContext.STRING().getText()); + } + + /** + * Parse a float from a JSONParser.PairContext + * + * @param pairContext the JSONParser.PairContext to parse + * @return the parsed float + */ + public static float parseFloat(JSONParser.PairContext pairContext) { + return Float.parseFloat(pairContext.STRING().getText()); + } + + /** + * Parse a double from a JSONParser.PairContext + * + * @param pairContext the JSONParser.PairContext to parse + * @return the parsed double + */ + public static double parseDouble(JSONParser.PairContext pairContext) { + return Double.parseDouble(pairContext.STRING().getText()); + } + + /** + * Parse a boolean from a JSONParser.PairContext + * + * @param pairContext the JSONParser.PairContext to parse + * @return the parsed boolean + */ + public static boolean parseBoolean(JSONParser.PairContext pairContext) { + return Boolean.parseBoolean(pairContext.STRING().getText()); + } + // ==================================================================================================== // To JSON String Methods @@ -249,6 +302,32 @@ public static String field(String fieldName, byte[] value) { return rawFieldCode(fieldName, '"' + Base64.getEncoder().encodeToString(value) + '"'); } + /** + * Map field to JSON string + * + * @param fieldName the name of the field + * @param value the value of the field + * @param kEncoder an encoder of a key value to a string + * @param vComposer a composer of a "key":value strings - basically, a JsonTools::field method for the value type + * @return the JSON string + */ + public static String field(String fieldName, Map value, Function kEncoder, BiFunction vComposer) { + assert !value.isEmpty(); + StringBuilder sb = new StringBuilder(); + PbjMap pbjMap = (PbjMap) value; + for (int i = 0; i < pbjMap.size(); i++) { + if (i > 0) { + sb.append(",\n"); + } + K k = pbjMap.getSortedKeys().get(i); + V v = pbjMap.get(k); + + String keyStr = kEncoder.apply(k); + sb.append(vComposer.apply(keyStr, v)); + } + return rawFieldCode(fieldName, "{\n" + sb.toString().indent(4) + " }"); + } + /** * Primitive boolean field to JSON string * @@ -422,6 +501,7 @@ public static String arrayField(String fieldName, case BOOL -> Boolean.toString((Boolean) item); case ENUM -> '"' + ((EnumWithProtoMetadata)item).protoName() + '"'; case MESSAGE -> throw new UnsupportedOperationException("No expected here should have called other arrayField() method"); + case MAP -> throw new UnsupportedOperationException("Arrays of maps not supported"); }; } }) diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjMap.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjMap.java new file mode 100644 index 00000000..34d42262 --- /dev/null +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/PbjMap.java @@ -0,0 +1,137 @@ +package com.hedera.pbj.runtime; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Implements an immutable map that exposes a list of keys sorted in their natural order. + *

+ * This Map implementation allows one to iterate the entries in a deterministic order + * which is useful for serializing, hash computation, etc. + * + * @param key type + * @param value type + */ +public class PbjMap implements Map { + /** An empty PbjMap. */ + public static final PbjMap EMPTY = new PbjMap(Collections.EMPTY_MAP); + + private final Map map; + private final List sortedKeys; + + private PbjMap(final Map map) { + this.map = Collections.unmodifiableMap(map); + this.sortedKeys = Collections.unmodifiableList(map.keySet().stream().sorted().toList()); + } + + /** + * A public factory method for PbjMap objects. + * It returns the PbjMap.EMPTY if the input map is empty. + * It returns the map itself if the input map is an instance of PbjMap (because it's immutable anyway.) + * Otherwise, it returns a new PbjMap instance delegating to the provided input map. + * NOTE: the caller code is expected to never modify the input map after this factory method is called, + * otherwise the behavior is undefined. + * @param map an input map + * @return a PbjMap instance corresponding to the input map + * @param key type + * @param value type + */ + public static PbjMap of(final Map map) { + if (map == null || map.isEmpty()) return (PbjMap) EMPTY; + if (map instanceof PbjMap) return (PbjMap) map; + return new PbjMap<>(map); + } + + /** + * Return a list of keys sorted in their natural order. + * @return the sorted keys list + */ + public List getSortedKeys() { + return sortedKeys; + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean isEmpty() { + return map.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return map.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return map.containsValue(value); + } + + @Override + public V get(Object key) { + return map.get(key); + } + + @Override + public V put(K key, V value) { + throw new UnsupportedOperationException("The map is immutable"); + } + + @Override + public V remove(Object key) { + throw new UnsupportedOperationException("The map is immutable"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("The map is immutable"); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("The map is immutable"); + } + + @Override + public Set keySet() { + return map.keySet(); + } + + @Override + public Collection values() { + return map.values(); + } + + @Override + public Set> entrySet() { + return map.entrySet(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PbjMap pbjMap = (PbjMap) o; + return Objects.equals(map, pbjMap.map) && Objects.equals(sortedKeys, pbjMap.sortedKeys); + } + + @Override + public int hashCode() { + // This is a convenience hashCode() implementation that delegates to Java hashCode, + // and it's implemented here solely to support the above equals() method override. + // Generated protobuf models compute map fields' hash codes differently and deterministically. + return 31 * map.hashCode() + sortedKeys.hashCode(); + } + + @Override + public String toString() { + return map.toString() + " with sortedKeys: " + getSortedKeys(); + } +} diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java index 2e546978..c5570536 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoParserTools.java @@ -12,7 +12,9 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * This class is full of parse helper methods, they depend on a DataInput as input with position and limit set @@ -48,6 +50,25 @@ public static List addToList(List list, T newItem) { return list; } + /** + * Add an entry to a map returning a new map with the entry or the same map with the entry added. If the map is + * Collections.EMPTY_MAP then a new map is created and returned with the entry added. + * + * @param map The map to add entry to or Collections.EMPTY_MAP + * @param key The key + * @param value The value + * @return The map passed in if mutable or new map + * @param The type of keys + * @param The type of values + */ + public static Map addToMap(Map map, final K key, final V value) { + if (map == PbjMap.EMPTY) { + map = new HashMap<>(); + } + map.put(key, value); + return map; + } + /** * Read a protobuf int32 from input * diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoTestTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoTestTools.java index df0e0833..b422d9c0 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoTestTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoTestTools.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; /** * Static tools and test cases used by generated test classes. diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java index 1b36b5e1..b7ffa004 100644 --- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java +++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java @@ -8,6 +8,7 @@ import com.hedera.pbj.runtime.io.WritableSequentialData; import com.hedera.pbj.runtime.io.buffer.Bytes; import com.hedera.pbj.runtime.io.buffer.RandomAccessData; +import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; import java.io.IOException; @@ -46,7 +47,7 @@ public static ProtoConstants wireType(final FieldDefinition field) { case FIXED32, SFIXED32 -> WIRE_TYPE_FIXED_32_BIT; case FIXED64, SFIXED64 -> WIRE_TYPE_FIXED_64_BIT; case BOOL -> WIRE_TYPE_VARINT_OR_ZIGZAG; - case BYTES, MESSAGE, STRING -> WIRE_TYPE_DELIMITED; + case BYTES, MESSAGE, STRING, MAP -> WIRE_TYPE_DELIMITED; case ENUM -> WIRE_TYPE_VARINT_OR_ZIGZAG; }; } @@ -89,13 +90,25 @@ private static RuntimeException unsupported() { * @param value the int value to write */ public static void writeInteger(WritableSequentialData out, FieldDefinition field, int value) { + writeInteger(out, field, value, true); + } + + /** + * Write a integer to data output + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing + * @param value the int value to write + * @param skipDefault default value results in no-op for non-oneOf + */ + public static void writeInteger(WritableSequentialData out, FieldDefinition field, int value, boolean skipDefault) { assert switch(field.type()) { case INT32, UINT32, SINT32, FIXED32, SFIXED32 -> true; default -> false; } : "Not an integer type " + field; assert !field.repeated() : "Use writeIntegerList with repeated types"; - if (!field.oneOf() && value == 0) { + if (skipDefault && !field.oneOf() && value == 0) { return; } switch (field.type()) { @@ -129,12 +142,24 @@ assert switch(field.type()) { * @param value the long value to write */ public static void writeLong(WritableSequentialData out, FieldDefinition field, long value) { + writeLong(out, field, value, true); + } + + /** + * Write a long to data output + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing + * @param value the long value to write + * @param skipDefault default value results in no-op for non-oneOf + */ + public static void writeLong(WritableSequentialData out, FieldDefinition field, long value, boolean skipDefault) { assert switch(field.type()) { case INT64, UINT64, SINT64, FIXED64, SFIXED64 -> true; default -> false; } : "Not a long type " + field; assert !field.repeated() : "Use writeLongList with repeated types"; - if (!field.oneOf() && value == 0) { + if (skipDefault && !field.oneOf() && value == 0) { return; } switch (field.type()) { @@ -200,10 +225,22 @@ public static void writeDouble(WritableSequentialData out, FieldDefinition field * @param value the boolean value to write */ public static void writeBoolean(WritableSequentialData out, FieldDefinition field, boolean value) { + writeBoolean(out, field, value, true); + } + + /** + * Write a boolean to data output + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing + * @param value the boolean value to write + * @param skipDefault default value results in no-op for non-oneOf + */ + public static void writeBoolean(WritableSequentialData out, FieldDefinition field, boolean value, boolean skipDefault) { assert field.type() == FieldType.BOOL : "Not a boolean type " + field; assert !field.repeated() : "Use writeBooleanList with repeated types"; // In the case of oneOf we write the value even if it is default value of false - if (value || field.oneOf()) { + if (value || field.oneOf() || !skipDefault) { writeTag(out, field, WIRE_TYPE_VARINT_OR_ZIGZAG); out.writeByte(value ? (byte)1 : 0); } @@ -236,10 +273,24 @@ public static void writeEnum(WritableSequentialData out, FieldDefinition field, * @throws IOException If a I/O error occurs */ public static void writeString(final WritableSequentialData out, final FieldDefinition field, - final String value) throws IOException { + final String value) throws IOException { + writeString(out, field, value, true); + } + + /** + * Write a string to data output, assuming the field is non-repeated. + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing, the field must be non-repeated + * @param value the string value to write + * @param skipDefault default value results in no-op for non-oneOf + * @throws IOException If a I/O error occurs + */ + public static void writeString(final WritableSequentialData out, final FieldDefinition field, + final String value, boolean skipDefault) throws IOException { assert field.type() == FieldType.STRING : "Not a string type " + field; assert !field.repeated() : "Use writeStringList with repeated types"; - writeStringNoChecks(out, field, value); + writeStringNoChecks(out, field, value, skipDefault); } /** @@ -268,14 +319,28 @@ public static void writeOneRepeatedString(final WritableSequentialData out, fina * @throws IOException If a I/O error occurs */ private static void writeStringNoChecks(final WritableSequentialData out, final FieldDefinition field, - final String value) throws IOException { + final String value) throws IOException { + writeStringNoChecks(out, field, value, true); + } + + /** + * Write a integer to data output - no validation checks. + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing + * @param value the string value to write + * @param skipDefault default value results in no-op for non-oneOf + * @throws IOException If a I/O error occurs + */ + private static void writeStringNoChecks(final WritableSequentialData out, final FieldDefinition field, + final String value, boolean skipDefault) throws IOException { // When not a oneOf don't write default value - if (!field.oneOf() && (value == null || value.isEmpty())) { + if (skipDefault && !field.oneOf() && (value == null || value.isEmpty())) { return; } writeTag(out, field, WIRE_TYPE_DELIMITED); out.writeVarInt(sizeOfStringNoTag(value), false); - Utf8Tools.encodeUtf8(value,out); + Utf8Tools.encodeUtf8(value, out); } /** @@ -288,10 +353,25 @@ private static void writeStringNoChecks(final WritableSequentialData out, final * @throws IOException If a I/O error occurs */ public static void writeBytes(final WritableSequentialData out, final FieldDefinition field, - final RandomAccessData value) throws IOException { + final RandomAccessData value) throws IOException { + writeBytes(out, field, value, true); + } + + /** + * Write a bytes to data output, assuming the corresponding field is non-repeated, and field type + * is any delimited: bytes, string, or message. + * + * @param out The data output to write to + * @param field the descriptor for the field we are writing, the field must not be repeated + * @param value the bytes value to write + * @param skipDefault default value results in no-op for non-oneOf + * @throws IOException If a I/O error occurs + */ + public static void writeBytes(final WritableSequentialData out, final FieldDefinition field, + final RandomAccessData value, boolean skipDefault) throws IOException { assert field.type() == FieldType.BYTES : "Not a byte[] type " + field; assert !field.repeated() : "Use writeBytesList with repeated types"; - writeBytesNoChecks(out, field, value, true); + writeBytesNoChecks(out, field, value, skipDefault); } /** @@ -403,6 +483,38 @@ private static void writeMessageNoChecks(final WritableSequentialData out, f } } + public static void writeMap( + final WritableSequentialData out, + final FieldDefinition field, + @NonNull final PbjMap map, + final ProtoWriter kWriter, + final ProtoWriter vWriter, + final ToIntFunction sizeOfK, + final ToIntFunction sizeOfV + ) throws IOException { + // https://protobuf.dev/programming-guides/proto3/#maps + // On the wire, a map is equivalent to: + // message MapFieldEntry { + // key_type key = 1; + // value_type value = 2; + // } + // repeated MapFieldEntry map_field = N; + if (map.isEmpty()) { + return; + } + final int size = map.size(); + for (int i = 0; i < size; i++) { + K k = map.getSortedKeys().get(i); + V v = map.get(k); + writeTag(out, field, WIRE_TYPE_DELIMITED); + final int sizeK = sizeOfK.applyAsInt(k); + final int sizeV = sizeOfV.applyAsInt(v); + out.writeVarInt(sizeK + sizeV, false); + kWriter.write(k, out); + vWriter.write(v, out); + } + } + // ================================================================================================================ // OPTIONAL VERSIONS OF WRITE METHODS @@ -418,7 +530,7 @@ public static void writeOptionalInteger(WritableSequentialData out, FieldDefinit writeTag(out, field, WIRE_TYPE_DELIMITED); final var newField = field.type().optionalFieldDefinition; out.writeVarInt(sizeOfInteger(newField, value), false); - writeInteger(out,newField,value); + writeInteger(out, newField, value); } } @@ -434,7 +546,7 @@ public static void writeOptionalLong(WritableSequentialData out, FieldDefinition writeTag(out, field, WIRE_TYPE_DELIMITED); final var newField = field.type().optionalFieldDefinition; out.writeVarInt(sizeOfLong(newField, value), false); - writeLong(out,newField,value); + writeLong(out, newField, value); } } @@ -482,7 +594,7 @@ public static void writeOptionalBoolean(WritableSequentialData out, FieldDefinit writeTag(out, field, WIRE_TYPE_DELIMITED); final var newField = field.type().optionalFieldDefinition; out.writeVarInt(sizeOfBoolean(newField, value), false); - writeBoolean(out,newField,value); + writeBoolean(out, newField, value); } } @@ -499,7 +611,7 @@ public static void writeOptionalString(WritableSequentialData out, FieldDefiniti writeTag(out, field, WIRE_TYPE_DELIMITED); final var newField = field.type().optionalFieldDefinition; out.writeVarInt(sizeOfString(newField, value), false); - writeString(out,newField,value); + writeString(out, newField, value); } } @@ -1036,7 +1148,7 @@ public static int sizeOfOptionalBoolean(FieldDefinition field, @Nullable Boolean */ public static int sizeOfOptionalString(FieldDefinition field, @Nullable String value) { if (value != null) { - final int size = sizeOfString(field.type().optionalFieldDefinition,value); + final int size = sizeOfString(field.type().optionalFieldDefinition, value); return sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfUnsignedVarInt32(size) + size; } return 0; @@ -1065,7 +1177,19 @@ public static int sizeOfOptionalBytes(FieldDefinition field, @Nullable RandomAcc * @return the number of bytes for encoded value */ public static int sizeOfInteger(FieldDefinition field, int value) { - if (!field.oneOf() && value == 0) return 0; + return sizeOfInteger(field, value, true); + } + + /** + * Get number of bytes that would be needed to encode an integer field + * + * @param field descriptor of field + * @param value integer value to get encoded size for + * @param skipDefault default value results in zero size + * @return the number of bytes for encoded value + */ + public static int sizeOfInteger(FieldDefinition field, int value, boolean skipDefault) { + if (skipDefault && !field.oneOf() && value == 0) return 0; return switch (field.type()) { case INT32 -> sizeOfTag(field, WIRE_TYPE_VARINT_OR_ZIGZAG) + sizeOfVarInt32(value); case UINT32 -> sizeOfTag(field, WIRE_TYPE_VARINT_OR_ZIGZAG) + sizeOfUnsignedVarInt32(value); @@ -1083,7 +1207,19 @@ public static int sizeOfInteger(FieldDefinition field, int value) { * @return the number of bytes for encoded value */ public static int sizeOfLong(FieldDefinition field, long value) { - if (!field.oneOf() && value == 0) return 0; + return sizeOfLong(field, value, true); + } + + /** + * Get number of bytes that would be needed to encode a long field + * + * @param field descriptor of field + * @param value long value to get encoded size for + * @param skipDefault default value results in zero size + * @return the number of bytes for encoded value + */ + public static int sizeOfLong(FieldDefinition field, long value, boolean skipDefault) { + if (skipDefault && !field.oneOf() && value == 0) return 0; return switch (field.type()) { case INT64, UINT64 -> sizeOfTag(field, WIRE_TYPE_VARINT_OR_ZIGZAG) + sizeOfUnsignedVarInt64(value); case SINT64 -> sizeOfTag(field, WIRE_TYPE_VARINT_OR_ZIGZAG) + sizeOfUnsignedVarInt64((value << 1) ^ (value >> 63)); @@ -1124,7 +1260,19 @@ public static int sizeOfDouble(FieldDefinition field, double value) { * @return the number of bytes for encoded value */ public static int sizeOfBoolean(FieldDefinition field, boolean value) { - return (value || field.oneOf()) ? sizeOfTag(field, WIRE_TYPE_VARINT_OR_ZIGZAG) + 1 : 0; + return sizeOfBoolean(field, value, true); + } + + /** + * Get number of bytes that would be needed to encode a boolean field + * + * @param field descriptor of field + * @param value boolean value to get encoded size for + * @param skipDefault default value results in zero size + * @return the number of bytes for encoded value + */ + public static int sizeOfBoolean(FieldDefinition field, boolean value, boolean skipDefault) { + return (value || field.oneOf() || !skipDefault) ? sizeOfTag(field, WIRE_TYPE_VARINT_OR_ZIGZAG) + 1 : 0; } @@ -1150,8 +1298,20 @@ public static int sizeOfEnum(FieldDefinition field, EnumWithProtoMetadata enumVa * @return the number of bytes for encoded value */ public static int sizeOfString(FieldDefinition field, String value) { + return sizeOfString(field, value, true); + } + + /** + * Get number of bytes that would be needed to encode a string field + * + * @param field descriptor of field + * @param value string value to get encoded size for + * @param skipDefault default value results in zero size + * @return the number of bytes for encoded value + */ + public static int sizeOfString(FieldDefinition field, String value, boolean skipDefault) { // When not a oneOf don't write default value - if (!field.oneOf() && (value == null || value.isEmpty())) { + if (skipDefault && !field.oneOf() && (value == null || value.isEmpty())) { return 0; } return sizeOfDelimited(field, sizeOfStringNoTag(value)); @@ -1183,8 +1343,20 @@ private static int sizeOfStringNoTag(String value) { * @return the number of bytes for encoded value */ public static int sizeOfBytes(FieldDefinition field, RandomAccessData value) { + return sizeOfBytes(field, value, true); + } + + /** + * Get number of bytes that would be needed to encode a bytes field + * + * @param field descriptor of field + * @param value bytes value to get encoded size for + * @param skipDefault default value results in zero size + * @return the number of bytes for encoded value + */ + public static int sizeOfBytes(FieldDefinition field, RandomAccessData value, boolean skipDefault) { // When not a oneOf don't write default value - if (!field.oneOf() && (value.length() == 0)) { + if (skipDefault && !field.oneOf() && (value.length() == 0)) { return 0; } return sizeOfDelimited(field, (int) value.length()); diff --git a/pbj-integration-tests/src/main/proto/everything.proto b/pbj-integration-tests/src/main/proto/everything.proto index 3bf44885..57fd251b 100644 --- a/pbj-integration-tests/src/main/proto/everything.proto +++ b/pbj-integration-tests/src/main/proto/everything.proto @@ -33,6 +33,12 @@ message Everything { InnerEverything innerEverything = 50; + map mapInt32ToString = 71; + map mapBoolToDouble = 72; + map mapStringToMessage = 73; + map mapUInt64ToBytes = 74; + map mapInt64ToBool = 75; + repeated int32 int32NumberList = 100; repeated sint32 sint32NumberList = 102; repeated uint32 uint32NumberList = 103; diff --git a/pbj-integration-tests/src/main/proto/map.proto b/pbj-integration-tests/src/main/proto/map.proto new file mode 100644 index 00000000..a87ff84e --- /dev/null +++ b/pbj-integration-tests/src/main/proto/map.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package proto; + +option java_package = "com.hedera.pbj.test.proto.java"; +option java_multiple_files = true; +// <<>> This comment is special code for setting PBJ Compiler java package + +/** + * Sample protobuf containing maps. + */ +message MessageWithMaps { + /** A test map. */ + map mapInt32ToString = 1; +} + +/** + * Sample protobuf containing multiple different maps. + */ +message MessageWithManyMaps { + map mapInt32ToString = 1; + map mapBoolToDouble = 2; + map mapStringToMessage = 3; + map mapUInt64ToBytes = 4; + map mapInt64ToBool = 5; +} \ No newline at end of file diff --git a/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/SampleFuzzTest.java b/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/SampleFuzzTest.java index b8f0bee8..64b81ac5 100644 --- a/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/SampleFuzzTest.java +++ b/pbj-integration-tests/src/test/java/com/hedera/pbj/intergration/test/SampleFuzzTest.java @@ -70,7 +70,7 @@ public class SampleFuzzTest { * if the mean value of all the individual DESERIALIZATION_FAILED * shares is greater than this threshold. */ - private static final double DESERIALIZATION_FAILED_MEAN_THRESHOLD = .983; + private static final double DESERIALIZATION_FAILED_MEAN_THRESHOLD = .9829; /** * Fuzz tests are tagged with this tag to allow Gradle/JUnit