Skip to content

Commit

Permalink
fix: fixed union generator when processing union with custom enum value
Browse files Browse the repository at this point in the history
  • Loading branch information
sssooonnnggg committed Oct 25, 2024
1 parent 315b95b commit 2377038
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 41 deletions.
18 changes: 11 additions & 7 deletions src/FlatSharp.Compiler/SchemaModel/ReferenceUnionSchemaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class ReferenceUnionSchemaModel : BaseSchemaModel
protected override void OnWriteCode(CodeWriter writer, CompileContext context)
{
List<(string resolvedType, EnumVal value, Type? propertyType)> innerTypes = new();
int itemIndex = 1;
foreach (var inner in this.union.Values.Select(x => x.Value))
{
// Skip "none".
Expand All @@ -61,8 +62,9 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
FlatSharpInternal.Assert(previousType is not null, "PreviousType was null");

propertyClrType = previousType
.GetProperty($"Item{inner.Value}", BindingFlags.Public | BindingFlags.Instance)?
.GetProperty($"Item{itemIndex}", BindingFlags.Public | BindingFlags.Instance)?
.PropertyType;
++itemIndex;

FlatSharpInternal.Assert(propertyClrType is not null, "Couldn't find property");
}
Expand Down Expand Up @@ -98,11 +100,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine();
writer.AppendLine("public byte Discriminator => (byte)this.discriminator;");

int index = 1;
foreach (var item in innerTypes)
{
this.WriteConstructor(writer, item.resolvedType, item.value, item.propertyType);
this.AddUnionMember(writer, item.resolvedType, item.value, item.propertyType, context);
this.WriteConstructor(writer, index, item.resolvedType, item.value, item.propertyType);
this.AddUnionMember(writer, index, item.resolvedType, item.value, item.propertyType, context);
this.WriteImplicitOperator(writer, item.resolvedType);
++index;
}

this.WriteDefaultConstructor(writer);
Expand All @@ -112,13 +116,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
}
}

private void AddUnionMember(CodeWriter writer, string resolvedType, EnumVal value, Type? propertyClrType, CompileContext context)
private void AddUnionMember(CodeWriter writer, int index, string resolvedType, EnumVal value, Type? propertyClrType, CompileContext context)
{
writer.AppendLine();
writer.AppendLine($"private {resolvedType}{(propertyClrType?.IsValueType == false ? "?" : string.Empty)} value_{value.Value};");

writer.AppendLine();
writer.AppendLine($"public {resolvedType} Item{value.Value}");
writer.AppendLine($"public {resolvedType} Item{index}");
using (writer.WithBlock())
{
writer.AppendLine("get");
Expand Down Expand Up @@ -233,7 +237,7 @@ private void WriteMatchMethod(
}
}

private void WriteConstructor(CodeWriter writer, string resolvedType, EnumVal unionValue, Type? propertyType)
private void WriteConstructor(CodeWriter writer, int index, string resolvedType, EnumVal unionValue, Type? propertyType)
{
writer.AppendLine($"public {this.Name}({resolvedType} value)");
using (writer.WithBlock())
Expand All @@ -248,7 +252,7 @@ private void WriteConstructor(CodeWriter writer, string resolvedType, EnumVal un
}

writer.AppendLine($"this.discriminator = {unionValue.Value};");
writer.AppendLine($"this.Item{unionValue.Value} = value;");
writer.AppendLine($"this.Item{index} = value;");
}
}

Expand Down
33 changes: 20 additions & 13 deletions src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine();
writer.AppendLine("public byte Discriminator { get; }");

int index = 1;
foreach (var item in innerTypes)
{
Type? propertyClrType = null;
Expand All @@ -168,21 +169,21 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
FlatSharpInternal.Assert(previousType is not null, "PreviousType was null");

propertyClrType = previousType
.GetProperty($"Item{item.value.Value}", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance)?
.GetProperty($"Item{index}", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance)?
.PropertyType;

FlatSharpInternal.Assert(propertyClrType is not null, "Couldn't find property");
}

this.WriteConstructor(writer, item.resolvedType, item.value, propertyClrType, generateUnsafeItems);
this.WriteImplicitOperator(writer, item.resolvedType);
this.WriteUncheckedGetItemMethod(writer, item.resolvedType, item.value, propertyClrType, generateUnsafeItems);
this.WriteUncheckedGetItemMethod(writer, index, item.resolvedType, item.value, propertyClrType, generateUnsafeItems);

writer.AppendLine();
writer.AppendLine($"public {item.resolvedType} {item.value.Key} => this.Item{item.value.Value};");
writer.AppendLine($"public {item.resolvedType} {item.value.Key} => this.Item{index};");

writer.AppendLine();
writer.AppendLine($"public {item.resolvedType} Item{item.value.Value}");
writer.AppendLine($"public {item.resolvedType} Item{index}");
using (writer.WithBlock())
{
writer.AppendLine("get");
Expand All @@ -194,7 +195,7 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine($"{typeof(FSThrow).GGCTN()}.{nameof(FSThrow.InvalidOperation_UnionIsNotOfType)}();");
}

writer.AppendLine($"return this.UncheckedGetItem{item.value.Value}();");
writer.AppendLine($"return this.UncheckedGetItem{index}();");
}
}

Expand Down Expand Up @@ -225,9 +226,11 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine("return false;");
}

writer.AppendLine($"value = this.UncheckedGetItem{item.value.Value}();");
writer.AppendLine($"value = this.UncheckedGetItem{index}();");
writer.AppendLine("return true;");
}

++index;
}

this.WriteAcceptMethod(writer, innerTypes);
Expand All @@ -253,10 +256,12 @@ private void WriteAcceptMethod(
writer.AppendLine("switch (disc)");
using (writer.WithBlock())
{
int index = 1;
foreach (var item in components)
{
long index = item.value.Value;
writer.AppendLine($"case {index}: return visitor.Visit(this.UncheckedGetItem{item.value.Value}());");
long value = item.value.Value;
writer.AppendLine($"case {value}: return visitor.Visit(this.UncheckedGetItem{index}());");
++index;
}

writer.AppendLine($"default:");
Expand Down Expand Up @@ -286,10 +291,12 @@ private void WriteMatchMethod(
writer.AppendLine("switch (disc)");
using (writer.WithBlock())
{
int index = 1;
foreach (var item in components)
{
long index = item.value.Value;
writer.AppendLine($"case {index}: return case{item.value.Key}(this.UncheckedGetItem{item.value.Value}());");
long value = item.value.Value;
writer.AppendLine($"case {value}: return case{item.value.Key}(this.UncheckedGetItem{index}());");
++index;
}

writer.AppendLine($"default:");
Expand All @@ -302,12 +309,12 @@ private void WriteMatchMethod(
}
}

private void WriteUncheckedGetItemMethod(CodeWriter writer, string resolvedType, EnumVal unionValue, Type? propertyType, bool generateUnsafeItems)
private void WriteUncheckedGetItemMethod(CodeWriter writer, int index, string resolvedType, EnumVal unionValue, Type? propertyType, bool generateUnsafeItems)
{
if (propertyType?.IsValueType == true && generateUnsafeItems)
{
writer.AppendLine();
writer.AppendLine($"private {resolvedType} UncheckedGetItem{unionValue.Value}()");
writer.AppendLine($"private {resolvedType} UncheckedGetItem{index}()");
using (writer.WithBlock())
{
writer.AppendLine($"FlatSharpInternal.AssertSizeOf<{resolvedType}>({propertyType.StructLayoutAttribute!.Size});");
Expand All @@ -322,7 +329,7 @@ private void WriteUncheckedGetItemMethod(CodeWriter writer, string resolvedType,
else
{
writer.AppendLine();
writer.AppendLine($"private {resolvedType} UncheckedGetItem{unionValue.Value}()");
writer.AppendLine($"private {resolvedType} UncheckedGetItem{index}()");
using (writer.WithBlock())
{
writer.AppendLine($"return ({resolvedType})this.value;");
Expand Down
62 changes: 41 additions & 21 deletions src/FlatSharp/TypeModel/UnionTypeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace FlatSharp.TypeModel;
/// </summary>
public class UnionTypeModel : RuntimeTypeModel
{
private ITypeModel[] memberTypeModels;
private (byte, ITypeModel)[] memberTypeModels;

internal UnionTypeModel(Type unionType, TypeModelContainer provider) : base(unionType, provider)
{
Expand Down Expand Up @@ -70,25 +70,25 @@ internal UnionTypeModel(Type unionType, TypeModelContainer provider) : base(unio
/// <summary>
/// Gets the type model for this union's members. Index 0 corresponds to discriminator 1.
/// </summary>
public ITypeModel[] UnionElementTypeModel => this.memberTypeModels;
public (byte, ITypeModel)[] UnionElementTypeModel => this.memberTypeModels;

/// <summary>
/// We need it to pass through.
/// </summary>
public override TableFieldContextRequirements TableFieldContextRequirements =>
this.memberTypeModels.Select(x => x.TableFieldContextRequirements).Aggregate(TableFieldContextRequirements.None, (a, b) => a | b);
this.memberTypeModels.Select(model => model.Item2).Select(x => x.TableFieldContextRequirements).Aggregate(TableFieldContextRequirements.None, (a, b) => a | b);

/// <summary>
/// Unions have an implicit dependency on <see cref="byte"/> for the discriminator.
/// </summary>
public override IEnumerable<ITypeModel> Children => this.memberTypeModels.Concat(new[] { this.typeModelContainer.CreateTypeModel(typeof(byte)) });
public override IEnumerable<ITypeModel> Children => this.memberTypeModels.Select(model => model.Item2).Concat(new[] { this.typeModelContainer.CreateTypeModel(typeof(byte)) });

public override CodeGeneratedMethod CreateGetMaxSizeMethodBody(GetMaxSizeCodeGenContext context)
{
List<string> switchCases = new List<string>();
for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
var unionMember = this.UnionElementTypeModel[i];
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
int unionIndex = i + 1;

var itemContext = context with
Expand All @@ -98,7 +98,7 @@ public override CodeGeneratedMethod CreateGetMaxSizeMethodBody(GetMaxSizeCodeGen

string @case =
$@"
case {unionIndex}:
case {enumVal}:
return {sizeof(uint) + SerializationHelpers.GetMaxPadding(sizeof(uint))} + {itemContext.GetMaxSizeInvocation(unionMember.ClrType)};";

switchCases.Add(@case);
Expand Down Expand Up @@ -127,7 +127,7 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c

for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
var unionMember = this.UnionElementTypeModel[i];
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
int unionIndex = i + 1;

string inlineAdjustment = string.Empty;
Expand All @@ -144,7 +144,7 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c

string @case =
$@"
case {unionIndex}:
case {enumVal}:
{inlineAdjustment}
return {createNew}({itemContext.GetParseInvocation(unionMember.ClrType)});
";
Expand Down Expand Up @@ -182,8 +182,9 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c

for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
int unionIndex = i + 1; // unions start at 1.
string itemType = this.UnionElementTypeModel[i].GetGlobalCompilableTypeName();
int unionIndex = i + 1;
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
string itemType = unionMember.GetGlobalCompilableTypeName();

getOrCreates.Add($@"
public static {className} GetOrCreate({itemType} value)
Expand All @@ -193,7 +194,7 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c
union = new {className}();
}}
union.discriminator = {unionIndex};
union.discriminator = {enumVal};
union.Item{unionIndex} = value;
union.isAlive = 1;
Expand All @@ -202,13 +203,13 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c
");

string recursiveReturn = string.Empty;
if (typeof(IPoolableObject).IsAssignableFrom(this.UnionElementTypeModel[i].ClrType))
if (typeof(IPoolableObject).IsAssignableFrom(unionMember.ClrType))
{
recursiveReturn = $"this.Item{unionIndex}?.ReturnToPool(true);";
}

returnToPoolCases.Add($@"
case {unionIndex}:
case {enumVal}:
{{
{recursiveReturn}
this.Item{unionIndex} = default({itemType})!;
Expand Down Expand Up @@ -259,7 +260,7 @@ public override CodeGeneratedMethod CreateSerializeMethodBody(SerializationCodeG
List<string> switchCases = new List<string>();
for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
var elementModel = this.UnionElementTypeModel[i];
var (enumVal, elementModel) = this.UnionElementTypeModel[i];
var unionIndex = i + 1;

string inlineAdjustment;
Expand Down Expand Up @@ -287,7 +288,7 @@ public override CodeGeneratedMethod CreateSerializeMethodBody(SerializationCodeG

string @case =
$@"
case {unionIndex}:
case {enumVal}:
{{
{inlineAdjustment}
{caseContext.GetSerializeInvocation(elementModel.ClrType)};
Expand Down Expand Up @@ -321,11 +322,12 @@ public override CodeGeneratedMethod CreateCloneMethodBody(CloneCodeGenContext co

for (int i = 0; i < this.memberTypeModels.Length; ++i)
{
int discriminator = i + 1;
string cloneMethod = context.MethodNameMap[this.memberTypeModels[i].ClrType];
int index = i + 1;
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
var cloneMethod = context.MethodNameMap[unionMember.ClrType];
switchCases.Add($@"
case {discriminator}:
return new {this.GetGlobalCompilableTypeName()}({cloneMethod}({context.ItemVariableName}.Item{discriminator}));
case {enumVal}:
return new {this.GetGlobalCompilableTypeName()}({cloneMethod}({context.ItemVariableName}.Item{index}));
");
}

Expand All @@ -349,15 +351,33 @@ public override void Initialize()
Type unionType = this.ClrType.GetInterfaces()
.Single(x => x != typeof(IFlatBufferUnion) && typeof(IFlatBufferUnion).IsAssignableFrom(x));

this.memberTypeModels = unionType.GetGenericArguments().Select(this.typeModelContainer.CreateTypeModel).ToArray();
// Get enum value in union type
var enumFields = Enum.GetValues(this.ClrType.GetNestedType("ItemKind")!);
List<byte> enumVals = new();
foreach (var item in enumFields)
{
byte val = (byte)item;

// Skip ItemKind::NONE
if (val != 0)
{
enumVals.Add(val);
}
}

this.memberTypeModels = enumVals
.Zip(
unionType.GetGenericArguments(),
(enumVal, typeModel) => (enumVal, this.typeModelContainer.CreateTypeModel(typeModel)))
.ToArray();
}

public override void Validate()
{
base.Validate();
HashSet<Type> uniqueTypes = new HashSet<Type>();

foreach (var item in this.memberTypeModels)
foreach (var (_, item) in this.memberTypeModels)
{
FlatSharpInternal.Assert(
item.IsValidUnionMember,
Expand Down
15 changes: 15 additions & 0 deletions src/Tests/FlatSharpCompilerTests/UnionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,20 @@ struct ValueStruct (fs_valueStruct) {{ x : int; }}
new()));

Assert.Contains("FlatSharp unions may not contain duplicate types. Union = UnionTests.MyUnion", ex.Message);
}

[Fact]
public void Union_WithCustomEnumValue()
{
string schema = @"
namespace UnionTests;
table A {}
table B {}
union C { A = 2, B = 4 }
";

FlatSharpCompiler.CompileAndLoadAssembly(
schema,
new());
}
}

0 comments on commit 2377038

Please sign in to comment.