Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match with state - fixes #121 #133

Merged
merged 5 commits into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 188 additions & 1 deletion src/UnionGeneration/UnionSourceBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Text;
using System.Text;

namespace Dunet.UnionGeneration;

Expand Down Expand Up @@ -122,6 +122,53 @@ UnionDeclaration union
builder.AppendLine(" );");
builder.AppendLine();

// public abstract TMatchOutput Match<TState, TMatchOutput>(
// TState state,
// System.Func<TState, UnionVariant1<T1, T2, ...>, TMatchOutput> @unionVariant1,
// System.Func<TState, UnionVariant2<T1, T2, ...>, TMatchOutput> @unionVariant2,
// ...
// );
builder.AppendLine(" public abstract TMatchOutput Match<TState, TMatchOutput>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variant = union.Variants[i];
builder.Append($" System.Func<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.Append($", TMatchOutput> {variant.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine(" );");

// public abstract void Match<TState>(
// TState state,
// System.Action<TState, UnionVariant1<T1, T2, ...>> @unionVariant1,
// System.Action<TState, UnionVariant2<T1, T2, ...>> @unionVariant2,
// ...
// );
builder.AppendLine(" public abstract void Match<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variant = union.Variants[i];
builder.Append($" System.Action<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.Append($"> {variant.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine(" );");
builder.AppendLine();

return builder;
}

Expand Down Expand Up @@ -164,6 +211,44 @@ UnionDeclaration union

builder.AppendLine();

foreach (var variant in union.Variants)
{
// public abstract TMatchOutput MatchSpecific<TState, TMatchOutput>(
// TState state,
// System.Func<TState, Specific<T1, T2, ...>, TMatchOutput> @specific,
// System.Func<TState, TMatchOutput> @else
// );
builder.AppendLine($" public abstract TMatchOutput Match{variant.Identifier}<TState, TMatchOutput>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Func<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.AppendLine($", TMatchOutput> {variant.Identifier.ToMethodParameterCase()},");
builder.AppendLine($" System.Func<TState, TMatchOutput> @else");
builder.AppendLine(" );");
}

builder.AppendLine();

foreach (var variant in union.Variants)
{
// public abstract void MatchSpecific<TState>(
// TState state,
// System.Action<TState, Specific<T1, T2, ...>> @specific,
// System.Action<TState> @else
// );
builder.AppendLine($" public abstract void Match{variant.Identifier}<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Action<TState, {variant.Identifier}");
builder.AppendTypeParams(variant.TypeParameters);
builder.AppendLine($"> {variant.Identifier.ToMethodParameterCase()},");
builder.AppendLine($" System.Action<TState> @else");
builder.AppendLine(" );");
}

builder.AppendLine();

return builder;
}

Expand Down Expand Up @@ -213,6 +298,52 @@ VariantDeclaration variant
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(this);");

// public override TMatchOutput Match<TState, TMatchOutput>(
// TState state,
// System.Func<TState, UnionVariant1<T1, T2, ...>, TMatchOutput> @unionVariant1,
// System.Func<TState, UnionVariant2<T1, T2, ...>, TMatchOutput> @unionVariant2,
// ...
// ) => unionVariantX(state, this);
builder.AppendLine(" public override TMatchOutput Match<TState, TMatchOutput>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variantParam = union.Variants[i];
builder.Append($" System.Func<TState, {variantParam.Identifier}");
builder.AppendTypeParams(variantParam.TypeParameters);
builder.Append($", TMatchOutput> {variantParam.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);");

// public override void Match<TState>(
// TState state,
// System.Action<TState, UnionVariant1<T1, T2, ...>> @unionVariant1,
// System.Action<TState, UnionVariant2<T1, T2, ...>> @unionVariant2,
// ...
// ) => unionVariantX(state, this);
builder.AppendLine(" public override void Match<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
for (int i = 0; i < union.Variants.Count; ++i)
{
var variantParam = union.Variants[i];
builder.Append($" System.Action<TState, {variantParam.Identifier}");
builder.AppendTypeParams(variantParam.TypeParameters);
builder.Append($"> {variantParam.Identifier.ToMethodParameterCase()}");
if (i < union.Variants.Count - 1)
{
builder.Append(",");
}
builder.AppendLine();
}
builder.AppendLine($" ) => {variant.Identifier.ToMethodParameterCase()}(state, this);");

return builder;
}

Expand Down Expand Up @@ -272,6 +403,62 @@ VariantDeclaration variant
}
}

// public override TMatchOutput MatchVariantX<TState, TMatchOutput>(
// TState state,
// System.Func<TState, UnionVariant1<T1, T2, ...>, TMatchOutput> @unionVariantX,
// System.Func<TState, TMatchOutput> @else,
// ...
// ) => unionVariantX(state, this);
foreach (var specificVariant in union.Variants)
{
builder.AppendLine(
$" public override TMatchOutput Match{specificVariant.Identifier}<TState, TMatchOutput>("
);
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Func<TState, {specificVariant.Identifier}");
builder.AppendTypeParams(specificVariant.TypeParameters);
builder.AppendLine(
$", TMatchOutput> {specificVariant.Identifier.ToMethodParameterCase()},"
);
builder.AppendLine($" System.Func<TState, TMatchOutput> @else");
builder.Append(" ) => ");
if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine($"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);");
}
else
{
builder.AppendLine("@else(state);");
}
}

// public override void MatchVariantX<TState>(
// TState state,
// System.Action<TState, UnionVariant1<T1, T2, ...>> @unionVariantX,
// System.Action<TState> @else,
// ...
// ) => unionVariantX(state, this);
foreach (var specificVariant in union.Variants)
{
builder.AppendLine($" public override void Match{specificVariant.Identifier}<TState>(");
builder.Append($" TState state");
builder.AppendLine(union.Variants.Count > 0 ? "," : string.Empty);
builder.Append($" System.Action<TState, {specificVariant.Identifier}");
builder.AppendTypeParams(specificVariant.TypeParameters);
builder.AppendLine($"> {specificVariant.Identifier.ToMethodParameterCase()},");
builder.AppendLine($" System.Action<TState> @else");
builder.Append(" ) => ");
if (specificVariant.Identifier == variant.Identifier)
{
builder.AppendLine($"{specificVariant.Identifier.ToMethodParameterCase()}(state, this);");
}
else
{
builder.AppendLine("@else(state);");
}
}

builder.AppendLine(" }");
builder.AppendLine();

Expand Down
127 changes: 127 additions & 0 deletions test/UnionGeneration/MatchMethodWithStateTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
namespace Dunet.Test.UnionGeneration;

public sealed class MatchMethodWithStateTests
{
[Fact]
public void CanUseUnionTypesInDedicatedMatchMethod()
{
// Arrange.
var source = """
using Dunet;

Shape shape = new Shape.Rectangle(3, 4);
double state = 2d;

var area = shape.Match(
state,
static (s, circle) => s + 3.14 * circle.Radius * circle.Radius,
static (s, rectangle) => s + rectangle.Length * rectangle.Width,
static (s, triangle) => s + triangle.Base * triangle.Height / 2
);

[Union]
partial record Shape
{
partial record Circle(double Radius);
partial record Rectangle(double Length, double Width);
partial record Triangle(double Base, double Height);
}
""";
// Act.
var result = Compiler.Compile(source);

// Assert.
using var scope = new AssertionScope();
result.CompilationErrors.Should().BeEmpty();
result.GenerationErrors.Should().BeEmpty();
}

[Theory]
[InlineData("Shape shape = new Shape.Rectangle(3, 4);", 14d)]
[InlineData("Shape shape = new Shape.Circle(1);", 5.14d)]
[InlineData("Shape shape = new Shape.Triangle(4, 2);", 6d)]
public void MatchMethodCallsCorrectFunctionArgument(
string shapeDeclaration,
double expectedArea
)
{
// Arrange.
var source = $$"""
using Dunet;

static double GetArea()
{
{{shapeDeclaration}}
double state = 2d;
return shape.Match(
state,
static (s, circle) => s + 3.14 * circle.Radius * circle.Radius,
static (s, rectangle) => s + rectangle.Length * rectangle.Width,
static (s, triangle) => s + triangle.Base * triangle.Height / 2
);
}

[Union]
partial record Shape
{
partial record Circle(double Radius);
partial record Rectangle(double Length, double Width);
partial record Triangle(double Base, double Height);
}
""";
// Act.
var result = Compiler.Compile(source);
var actualArea = result.Assembly?.ExecuteStaticMethod<double>("GetArea");

// Assert.
using var scope = new AssertionScope();
result.CompilationErrors.Should().BeEmpty();
result.GenerationErrors.Should().BeEmpty();
actualArea.Should().BeApproximately(expectedArea, 0.0000000001d);
}

[Theory]
[InlineData("Keyword keyword = new Keyword.New();" , "string state = \"new\";", "new")]
[InlineData("Keyword keyword = new Keyword.Base();", "string state = \"base\";", "base")]
[InlineData("Keyword keyword = new Keyword.Null();", "string state = \"null\";", "null")]
public void CanMatchOnUnionVariantsNamedAfterKeywords(
string keywordDeclaration,
string stateDeclaration,
string expectedKeyword
)
{
// Arrange.
var source = $$"""
using Dunet;

static string GetKeyword()
{
{{keywordDeclaration}}
{{stateDeclaration}}
return keyword.Match(
state,
static (s, @new) => s,
static (s, @base) => s,
static (s, @null) => s
);
}

[Union]
partial record Keyword
{
partial record New;
partial record Base;
partial record Null;
}
""";
// Act.
var result = Compiler.Compile(source);
var actualKeyword = result.Assembly?.ExecuteStaticMethod<string>("GetKeyword");

// Assert.
using var scope = new AssertionScope();
result.CompilationErrors.Should().BeEmpty();
result.GenerationErrors.Should().BeEmpty();
actualKeyword.Should().Be(expectedKeyword);
}
}
Loading