Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed union generator when processing union with custom enum value #447

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/FlatSharp.Compiler/SchemaModel/ReferenceUnionSchemaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class ReferenceUnionSchemaModel : BaseSchemaModel
protected override void OnWriteCode(CodeWriter writer, CompileContext context)
{
List<(string resolvedType, EnumVal value, Type? propertyType)> innerTypes = new();
int itemIndex = 1;
foreach (var inner in this.union.Values.Select(x => x.Value))
{
// Skip "none".
Expand All @@ -61,8 +62,9 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
FlatSharpInternal.Assert(previousType is not null, "PreviousType was null");

propertyClrType = previousType
.GetProperty($"Item{inner.Value}", BindingFlags.Public | BindingFlags.Instance)?
.GetProperty($"Item{itemIndex}", BindingFlags.Public | BindingFlags.Instance)?
.PropertyType;
++itemIndex;

FlatSharpInternal.Assert(propertyClrType is not null, "Couldn't find property");
}
Expand Down Expand Up @@ -98,11 +100,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine();
writer.AppendLine("public byte Discriminator => (byte)this.discriminator;");

int index = 1;
foreach (var item in innerTypes)
{
this.WriteConstructor(writer, item.resolvedType, item.value, item.propertyType);
this.AddUnionMember(writer, item.resolvedType, item.value, item.propertyType, context);
this.WriteConstructor(writer, index, item.resolvedType, item.value, item.propertyType);
this.AddUnionMember(writer, index, item.resolvedType, item.value, item.propertyType, context);
this.WriteImplicitOperator(writer, item.resolvedType);
++index;
}

this.WriteDefaultConstructor(writer);
Expand All @@ -112,13 +116,13 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
}
}

private void AddUnionMember(CodeWriter writer, string resolvedType, EnumVal value, Type? propertyClrType, CompileContext context)
private void AddUnionMember(CodeWriter writer, int index, string resolvedType, EnumVal value, Type? propertyClrType, CompileContext context)
{
writer.AppendLine();
writer.AppendLine($"private {resolvedType}{(propertyClrType?.IsValueType == false ? "?" : string.Empty)} value_{value.Value};");

writer.AppendLine();
writer.AppendLine($"public {resolvedType} Item{value.Value}");
writer.AppendLine($"public {resolvedType} Item{index}");
using (writer.WithBlock())
{
writer.AppendLine("get");
Expand Down Expand Up @@ -233,7 +237,7 @@ private void WriteMatchMethod(
}
}

private void WriteConstructor(CodeWriter writer, string resolvedType, EnumVal unionValue, Type? propertyType)
private void WriteConstructor(CodeWriter writer, int index, string resolvedType, EnumVal unionValue, Type? propertyType)
{
writer.AppendLine($"public {this.Name}({resolvedType} value)");
using (writer.WithBlock())
Expand All @@ -248,7 +252,7 @@ private void WriteConstructor(CodeWriter writer, string resolvedType, EnumVal un
}

writer.AppendLine($"this.discriminator = {unionValue.Value};");
writer.AppendLine($"this.Item{unionValue.Value} = value;");
writer.AppendLine($"this.Item{index} = value;");
}
}

Expand Down
33 changes: 20 additions & 13 deletions src/FlatSharp.Compiler/SchemaModel/ValueUnionSchemaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine();
writer.AppendLine("public byte Discriminator { get; }");

int index = 1;
foreach (var item in innerTypes)
{
Type? propertyClrType = null;
Expand All @@ -168,21 +169,21 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
FlatSharpInternal.Assert(previousType is not null, "PreviousType was null");

propertyClrType = previousType
.GetProperty($"Item{item.value.Value}", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance)?
.GetProperty($"Item{index}", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Instance)?
.PropertyType;

FlatSharpInternal.Assert(propertyClrType is not null, "Couldn't find property");
}

this.WriteConstructor(writer, item.resolvedType, item.value, propertyClrType, generateUnsafeItems);
this.WriteImplicitOperator(writer, item.resolvedType);
this.WriteUncheckedGetItemMethod(writer, item.resolvedType, item.value, propertyClrType, generateUnsafeItems);
this.WriteUncheckedGetItemMethod(writer, index, item.resolvedType, item.value, propertyClrType, generateUnsafeItems);

writer.AppendLine();
writer.AppendLine($"public {item.resolvedType} {item.value.Key} => this.Item{item.value.Value};");
writer.AppendLine($"public {item.resolvedType} {item.value.Key} => this.Item{index};");

writer.AppendLine();
writer.AppendLine($"public {item.resolvedType} Item{item.value.Value}");
writer.AppendLine($"public {item.resolvedType} Item{index}");
using (writer.WithBlock())
{
writer.AppendLine("get");
Expand All @@ -194,7 +195,7 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine($"{typeof(FSThrow).GGCTN()}.{nameof(FSThrow.InvalidOperation_UnionIsNotOfType)}();");
}

writer.AppendLine($"return this.UncheckedGetItem{item.value.Value}();");
writer.AppendLine($"return this.UncheckedGetItem{index}();");
}
}

Expand Down Expand Up @@ -225,9 +226,11 @@ protected override void OnWriteCode(CodeWriter writer, CompileContext context)
writer.AppendLine("return false;");
}

writer.AppendLine($"value = this.UncheckedGetItem{item.value.Value}();");
writer.AppendLine($"value = this.UncheckedGetItem{index}();");
writer.AppendLine("return true;");
}

++index;
}

this.WriteAcceptMethod(writer, innerTypes);
Expand All @@ -253,10 +256,12 @@ private void WriteAcceptMethod(
writer.AppendLine("switch (disc)");
using (writer.WithBlock())
{
int index = 1;
foreach (var item in components)
{
long index = item.value.Value;
writer.AppendLine($"case {index}: return visitor.Visit(this.UncheckedGetItem{item.value.Value}());");
long value = item.value.Value;
writer.AppendLine($"case {value}: return visitor.Visit(this.UncheckedGetItem{index}());");
++index;
}

writer.AppendLine($"default:");
Expand Down Expand Up @@ -286,10 +291,12 @@ private void WriteMatchMethod(
writer.AppendLine("switch (disc)");
using (writer.WithBlock())
{
int index = 1;
foreach (var item in components)
{
long index = item.value.Value;
writer.AppendLine($"case {index}: return case{item.value.Key}(this.UncheckedGetItem{item.value.Value}());");
long value = item.value.Value;
writer.AppendLine($"case {value}: return case{item.value.Key}(this.UncheckedGetItem{index}());");
++index;
}

writer.AppendLine($"default:");
Expand All @@ -302,12 +309,12 @@ private void WriteMatchMethod(
}
}

private void WriteUncheckedGetItemMethod(CodeWriter writer, string resolvedType, EnumVal unionValue, Type? propertyType, bool generateUnsafeItems)
private void WriteUncheckedGetItemMethod(CodeWriter writer, int index, string resolvedType, EnumVal unionValue, Type? propertyType, bool generateUnsafeItems)
{
if (propertyType?.IsValueType == true && generateUnsafeItems)
{
writer.AppendLine();
writer.AppendLine($"private {resolvedType} UncheckedGetItem{unionValue.Value}()");
writer.AppendLine($"private {resolvedType} UncheckedGetItem{index}()");
using (writer.WithBlock())
{
writer.AppendLine($"FlatSharpInternal.AssertSizeOf<{resolvedType}>({propertyType.StructLayoutAttribute!.Size});");
Expand All @@ -322,7 +329,7 @@ private void WriteUncheckedGetItemMethod(CodeWriter writer, string resolvedType,
else
{
writer.AppendLine();
writer.AppendLine($"private {resolvedType} UncheckedGetItem{unionValue.Value}()");
writer.AppendLine($"private {resolvedType} UncheckedGetItem{index}()");
using (writer.WithBlock())
{
writer.AppendLine($"return ({resolvedType})this.value;");
Expand Down
62 changes: 41 additions & 21 deletions src/FlatSharp/TypeModel/UnionTypeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace FlatSharp.TypeModel;
/// </summary>
public class UnionTypeModel : RuntimeTypeModel
{
private ITypeModel[] memberTypeModels;
private (byte, ITypeModel)[] memberTypeModels;

internal UnionTypeModel(Type unionType, TypeModelContainer provider) : base(unionType, provider)
{
Expand Down Expand Up @@ -70,25 +70,25 @@ internal UnionTypeModel(Type unionType, TypeModelContainer provider) : base(unio
/// <summary>
/// Gets the type model for this union's members. Index 0 corresponds to discriminator 1.
/// </summary>
public ITypeModel[] UnionElementTypeModel => this.memberTypeModels;
public (byte, ITypeModel)[] UnionElementTypeModel => this.memberTypeModels;

/// <summary>
/// We need it to pass through.
/// </summary>
public override TableFieldContextRequirements TableFieldContextRequirements =>
this.memberTypeModels.Select(x => x.TableFieldContextRequirements).Aggregate(TableFieldContextRequirements.None, (a, b) => a | b);
this.memberTypeModels.Select(model => model.Item2).Select(x => x.TableFieldContextRequirements).Aggregate(TableFieldContextRequirements.None, (a, b) => a | b);

/// <summary>
/// Unions have an implicit dependency on <see cref="byte"/> for the discriminator.
/// </summary>
public override IEnumerable<ITypeModel> Children => this.memberTypeModels.Concat(new[] { this.typeModelContainer.CreateTypeModel(typeof(byte)) });
public override IEnumerable<ITypeModel> Children => this.memberTypeModels.Select(model => model.Item2).Concat(new[] { this.typeModelContainer.CreateTypeModel(typeof(byte)) });

public override CodeGeneratedMethod CreateGetMaxSizeMethodBody(GetMaxSizeCodeGenContext context)
{
List<string> switchCases = new List<string>();
for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
var unionMember = this.UnionElementTypeModel[i];
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
int unionIndex = i + 1;

var itemContext = context with
Expand All @@ -98,7 +98,7 @@ public override CodeGeneratedMethod CreateGetMaxSizeMethodBody(GetMaxSizeCodeGen

string @case =
$@"
case {unionIndex}:
case {enumVal}:
return {sizeof(uint) + SerializationHelpers.GetMaxPadding(sizeof(uint))} + {itemContext.GetMaxSizeInvocation(unionMember.ClrType)};";

switchCases.Add(@case);
Expand Down Expand Up @@ -127,7 +127,7 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c

for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
var unionMember = this.UnionElementTypeModel[i];
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
int unionIndex = i + 1;

string inlineAdjustment = string.Empty;
Expand All @@ -144,7 +144,7 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c

string @case =
$@"
case {unionIndex}:
case {enumVal}:
{inlineAdjustment}
return {createNew}({itemContext.GetParseInvocation(unionMember.ClrType)});
";
Expand Down Expand Up @@ -182,8 +182,9 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c

for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
int unionIndex = i + 1; // unions start at 1.
string itemType = this.UnionElementTypeModel[i].GetGlobalCompilableTypeName();
int unionIndex = i + 1;
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
string itemType = unionMember.GetGlobalCompilableTypeName();

getOrCreates.Add($@"
public static {className} GetOrCreate({itemType} value)
Expand All @@ -193,7 +194,7 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c
union = new {className}();
}}

union.discriminator = {unionIndex};
union.discriminator = {enumVal};
union.Item{unionIndex} = value;
union.isAlive = 1;

Expand All @@ -202,13 +203,13 @@ public override CodeGeneratedMethod CreateParseMethodBody(ParserCodeGenContext c
");

string recursiveReturn = string.Empty;
if (typeof(IPoolableObject).IsAssignableFrom(this.UnionElementTypeModel[i].ClrType))
if (typeof(IPoolableObject).IsAssignableFrom(unionMember.ClrType))
{
recursiveReturn = $"this.Item{unionIndex}?.ReturnToPool(true);";
}

returnToPoolCases.Add($@"
case {unionIndex}:
case {enumVal}:
{{
{recursiveReturn}
this.Item{unionIndex} = default({itemType})!;
Expand Down Expand Up @@ -259,7 +260,7 @@ public override CodeGeneratedMethod CreateSerializeMethodBody(SerializationCodeG
List<string> switchCases = new List<string>();
for (int i = 0; i < this.UnionElementTypeModel.Length; ++i)
{
var elementModel = this.UnionElementTypeModel[i];
var (enumVal, elementModel) = this.UnionElementTypeModel[i];
var unionIndex = i + 1;

string inlineAdjustment;
Expand Down Expand Up @@ -287,7 +288,7 @@ public override CodeGeneratedMethod CreateSerializeMethodBody(SerializationCodeG

string @case =
$@"
case {unionIndex}:
case {enumVal}:
{{
{inlineAdjustment}
{caseContext.GetSerializeInvocation(elementModel.ClrType)};
Expand Down Expand Up @@ -321,11 +322,12 @@ public override CodeGeneratedMethod CreateCloneMethodBody(CloneCodeGenContext co

for (int i = 0; i < this.memberTypeModels.Length; ++i)
{
int discriminator = i + 1;
string cloneMethod = context.MethodNameMap[this.memberTypeModels[i].ClrType];
int index = i + 1;
var (enumVal, unionMember) = this.UnionElementTypeModel[i];
var cloneMethod = context.MethodNameMap[unionMember.ClrType];
switchCases.Add($@"
case {discriminator}:
return new {this.GetGlobalCompilableTypeName()}({cloneMethod}({context.ItemVariableName}.Item{discriminator}));
case {enumVal}:
return new {this.GetGlobalCompilableTypeName()}({cloneMethod}({context.ItemVariableName}.Item{index}));
");
}

Expand All @@ -349,15 +351,33 @@ public override void Initialize()
Type unionType = this.ClrType.GetInterfaces()
.Single(x => x != typeof(IFlatBufferUnion) && typeof(IFlatBufferUnion).IsAssignableFrom(x));

this.memberTypeModels = unionType.GetGenericArguments().Select(this.typeModelContainer.CreateTypeModel).ToArray();
// Get enum value in union type
var enumFields = Enum.GetValues(this.ClrType.GetNestedType("ItemKind")!);
List<byte> enumVals = new();
foreach (var item in enumFields)
{
byte val = (byte)item;

// Skip ItemKind::NONE
if (val != 0)
{
enumVals.Add(val);
}
}

this.memberTypeModels = enumVals
.Zip(
unionType.GetGenericArguments(),
(enumVal, typeModel) => (enumVal, this.typeModelContainer.CreateTypeModel(typeModel)))
.ToArray();
}

public override void Validate()
{
base.Validate();
HashSet<Type> uniqueTypes = new HashSet<Type>();

foreach (var item in this.memberTypeModels)
foreach (var (_, item) in this.memberTypeModels)
{
FlatSharpInternal.Assert(
item.IsValidUnionMember,
Expand Down
15 changes: 15 additions & 0 deletions src/Tests/FlatSharpCompilerTests/UnionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,20 @@ struct ValueStruct (fs_valueStruct) {{ x : int; }}
new()));

Assert.Contains("FlatSharp unions may not contain duplicate types. Union = UnionTests.MyUnion", ex.Message);
}

[Fact]
public void Union_WithCustomEnumValue()
{
string schema = @"
namespace UnionTests;
table A {}
sssooonnnggg marked this conversation as resolved.
Show resolved Hide resolved
table B {}
union C { A = 2, B = 4 }
sssooonnnggg marked this conversation as resolved.
Show resolved Hide resolved
";

FlatSharpCompiler.CompileAndLoadAssembly(
schema,
new());
}
}
Loading