Skip to content

Commit

Permalink
Validate ProtoReserved attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
ltrzesniewski committed Jan 11, 2024
1 parent ffacf28 commit 8dcf594
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 41 deletions.
20 changes: 20 additions & 0 deletions src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,24 @@ public void should_accept_discards(string message)
ParseValid(message);
}

[Test]
[TestCase("[ProtoReserved(2)] Foo(int a)", ExpectedResult = true)]
[TestCase("[ProtoReserved(2)] Foo(int a, int b)", ExpectedResult = false)]
[TestCase("[ProtoReserved(2)] Foo(int a, _, int b)", ExpectedResult = true)]
[TestCase("[ProtoReserved(2), ProtoReserved(4)] Foo(int a, _, int b)", ExpectedResult = true)]
[TestCase("[ProtoReserved(2), ProtoReserved(4)] Foo(int a, _, int b, int c)", ExpectedResult = false)]
[TestCase("[ProtoReserved(2), ProtoReserved(4)] Foo(int a, _, int b, _, int c)", ExpectedResult = true)]
[TestCase("[ProtoReserved(2, 4)] Foo(int a, _, int b, _, int c)", ExpectedResult = false)]
[TestCase("[ProtoReserved(2, 4)] Foo(int a, _, _, _, int b)", ExpectedResult = true)]
[TestCase("[ProtoReserved(2, 4)] Foo(int a, [5] int b)", ExpectedResult = true)]
[TestCase("[ProtoReserved(2, 4)] Foo(int a, [3] int b)", ExpectedResult = false)]
[TestCase("[ProtoReserved(\"lol\")] Foo()", ExpectedResult = false)]
[TestCase("[ProtoReserved(2, \"lol\")] Foo()", ExpectedResult = true)]
[TestCase("[ProtoReserved(2, 4, \"lol\")] Foo()", ExpectedResult = true)]
[TestCase("[ProtoReserved(4, 2)] Foo()", ExpectedResult = false)]
public bool should_validate_proto_reserved_attributes(string definitionText)
=> Parse(definitionText).IsValid;

[Test]
public void should_generate_reservation_for_discards()
{
Expand Down Expand Up @@ -817,7 +835,9 @@ private static ParsedContracts ParseInvalid(string definitionText)

private static ParsedContracts Parse(string definitionText)
{
Console.WriteLine();
Console.WriteLine("PARSE: {0}", definitionText);

var contracts = ParsedContracts.Parse(definitionText, "Some.Namespace");

foreach (var error in contracts.Errors)
Expand Down
44 changes: 11 additions & 33 deletions src/Abc.Zebus.MessageDsl/Analysis/AstProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,45 +154,23 @@ private static void AddReservations(MessageDefinition message)
{
if (parameter.IsDiscarded)
{
if (currentReservation.TryAddTag(parameter.Tag))
continue;

currentReservation.AddAttributeIfNotEmpty(message);
currentReservation = new ReservationRange(parameter.Tag);
if (currentReservation.TryAddTag(parameter.Tag, out var updatedReservation))
{
currentReservation = updatedReservation;
}
else
{
currentReservation.AddToMessage(message);
currentReservation = new ReservationRange(parameter.Tag);
}
}
else
{
currentReservation.AddAttributeIfNotEmpty(message);
currentReservation.AddToMessage(message);
currentReservation = ReservationRange.None;
}
}

currentReservation.AddAttributeIfNotEmpty(message);
}

private struct ReservationRange(int startTag)
{
public static ReservationRange None => default;

private readonly int _startTag = startTag;
private int _endTag = startTag;

public bool TryAddTag(int tag)
{
if (_startTag < AstValidator.ProtoMinTag || tag != _endTag + 1)
return false;

_endTag = tag;
return true;
}

public void AddAttributeIfNotEmpty(MessageDefinition message)
{
if (_startTag >= AstValidator.ProtoMinTag)
message.Attributes.Add(new AttributeDefinition(KnownTypes.ProtoReservedAttribute, _startTag == _endTag ? $"{_startTag}" : $"{_startTag}, {_endTag}"));
}

public override string ToString()
=> _startTag >= AstValidator.ProtoMinTag ? $"{_startTag} - {_endTag}" : "None";
currentReservation.AddToMessage(message);
}
}
6 changes: 6 additions & 0 deletions src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,19 @@ private void ValidateTags(MessageDefinition message)

foreach (var param in message.Parameters)
{
if (param.IsDiscarded)
continue;

var errorContext = param.ParseContext ?? message.ParseContext;

if (!IsValidTag(param.Tag))
_contracts.AddError(errorContext, $"Tag for parameter '{param.Name}' is not within the valid range ({param.Tag})");

if (!tags.Add(param.Tag))
_contracts.AddError(errorContext, $"Duplicate tag {param.Tag} on parameter {param.Name}");

if (message.ReservedRanges.Any(range => range.Contains(param.Tag)))
_contracts.AddError(errorContext, $"Tag {param.Tag} of parameter {param.Name} is reserved");
}

foreach (var attr in message.Attributes)
Expand Down
53 changes: 46 additions & 7 deletions src/Abc.Zebus.MessageDsl/Analysis/AttributeInterpreter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace Abc.Zebus.MessageDsl.Analysis;
internal class AttributeInterpreter
{
private static readonly Regex _reProtoIncludeParams = new(@"^\s*(?<tag>[0-9]+)\s*,\s*typeof\s*\(\s*(?<typeName>.+)\s*\)", RegexOptions.Compiled | RegexOptions.CultureInvariant);
private static readonly Regex _reProtoReservedParams = new(@"^\s*(?<startTag>[0-9]+)(?:\s*,\s*(?<endTag>[0-9]+))?", RegexOptions.Compiled | RegexOptions.CultureInvariant);

private readonly ParsedContracts _contracts;

Expand All @@ -17,12 +18,13 @@ public AttributeInterpreter(ParsedContracts contracts)

public void InterpretAttributes()
{
foreach (var messageDefinition in _contracts.Messages)
foreach (var message in _contracts.Messages)
{
CheckIfTransient(messageDefinition);
CheckIfRoutable(messageDefinition);
CheckIfTransient(message);
CheckIfRoutable(message);
ProcessProtoReservedAttributes(message);

foreach (var parameterDefinition in messageDefinition.Parameters)
foreach (var parameterDefinition in message.Parameters)
ProcessProtoMemberAttribute(parameterDefinition);
}
}
Expand Down Expand Up @@ -65,7 +67,7 @@ private void CheckIfRoutable(MessageDefinition message)
.SequenceEqual(Enumerable.Range(1, routableParamCount));

if (routableParamCount == 0)
_contracts.AddError(message.ParseContext, "A routable message must have parameters with routing positions");
_contracts.AddError(message.ParseContext, "A routable message must have arguments with routing positions");
else if (!isValidSequence)
_contracts.AddError(message.ParseContext, "Routing positions must form a continuous sequence starting with 1");
}
Expand All @@ -85,6 +87,43 @@ private static void CheckIfTransient(MessageDefinition message)
message.IsTransient = !message.IsCustom && message.Attributes.HasAttribute(KnownTypes.TransientAttribute);
}

private void ProcessProtoReservedAttributes(MessageDefinition message)
{
foreach (var attr in message.Attributes.GetAttributes(KnownTypes.ProtoReservedAttribute))
{
var errorContext = attr.ParseContext ?? message.ParseContext;

if (string.IsNullOrWhiteSpace(attr.Parameters))
{
_contracts.AddError(errorContext, $"The [{KnownTypes.ProtoReservedAttribute}] attribute must have arguments");
return;
}

var match = _reProtoReservedParams.Match(attr.Parameters);
if (!match.Success || !int.TryParse(match.Groups["startTag"].Value, out var startTag))
{
_contracts.AddError(errorContext, $"Invalid [{KnownTypes.ProtoReservedAttribute}] arguments");
return;
}

var endTag = startTag;

if (match.Groups["endTag"].Success && !int.TryParse(match.Groups["endTag"].Value, out endTag))
{
_contracts.AddError(errorContext, $"Invalid [{KnownTypes.ProtoReservedAttribute}] arguments");
return;
}

if (startTag > endTag)
{
_contracts.AddError(errorContext, $"Invalid [{KnownTypes.ProtoReservedAttribute}] tag range");
return;
}

message.ReservedRanges.Add(new ReservationRange(startTag, endTag));
}
}

private void ProcessProtoMemberAttribute(ParameterDefinition param)
{
var attr = param.Attributes.GetAttribute(KnownTypes.ProtoMemberAttribute);
Expand All @@ -93,14 +132,14 @@ private void ProcessProtoMemberAttribute(ParameterDefinition param)

if (string.IsNullOrWhiteSpace(attr.Parameters))
{
_contracts.AddError(attr.ParseContext, $"The [{KnownTypes.ProtoMemberAttribute}] attribute must have parameters");
_contracts.AddError(attr.ParseContext, $"The [{KnownTypes.ProtoMemberAttribute}] attribute must have arguments");
return;
}

var match = Regex.Match(attr.Parameters, @"^\s*(?<nb>[0-9]+)\s*(?:,|$)");
if (!match.Success || !int.TryParse(match.Groups["nb"].Value, out var tagNb))
{
_contracts.AddError(attr.ParseContext, $"Invalid [{KnownTypes.ProtoMemberAttribute}] parameters");
_contracts.AddError(attr.ParseContext, $"Invalid [{KnownTypes.ProtoMemberAttribute}] arguments");
return;
}

Expand Down
46 changes: 46 additions & 0 deletions src/Abc.Zebus.MessageDsl/Analysis/ReservationRange.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Abc.Zebus.MessageDsl.Ast;

namespace Abc.Zebus.MessageDsl.Analysis;

internal readonly struct ReservationRange(int startTag, int endTag)
{
public static ReservationRange None => default;

private bool IsValid => startTag > 0;

public ReservationRange(int tag)
: this(tag, tag)
{
}

public bool Contains(int tag)
=> IsValid && tag >= startTag && tag <= endTag;

public bool TryAddTag(int tag, out ReservationRange updatedRange)
{
if (IsValid && tag == endTag + 1)
{
updatedRange = new ReservationRange(startTag, tag);
return true;
}

updatedRange = default;
return false;
}

public void AddToMessage(MessageDefinition message)
{
if (!IsValid)
return;

message.Attributes.Add(new AttributeDefinition(KnownTypes.ProtoReservedAttribute, startTag == endTag ? $"{startTag}" : $"{startTag}, {endTag}"));
message.ReservedRanges.Add(this);
}

public override string ToString()
=> IsValid
? startTag != endTag
? $"{startTag} - {endTag}"
: $"{startTag}"
: "None";
}
5 changes: 4 additions & 1 deletion src/Abc.Zebus.MessageDsl/Ast/AttributeSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ public class AttributeSet : AstNode, IList<AttributeDefinition>
public IList<AttributeDefinition> Attributes { get; } = new List<AttributeDefinition>();

public AttributeDefinition? GetAttribute(TypeName attributeType)
=> Attributes.Count != 0 ? GetAttributes(attributeType).FirstOrDefault() : null;

public IEnumerable<AttributeDefinition> GetAttributes(TypeName attributeType)
{
attributeType = AttributeDefinition.NormalizeAttributeTypeName(attributeType);
return Attributes.FirstOrDefault(attr => Equals(attr.TypeName, attributeType));
return Attributes.Where(attr => Equals(attr.TypeName, attributeType));
}

public bool HasAttribute(TypeName attributeType)
Expand Down
2 changes: 2 additions & 0 deletions src/Abc.Zebus.MessageDsl/Ast/MessageDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public MessageType Type
}
}

internal List<ReservationRange> ReservedRanges { get; } = new();

public override string ToString()
=> Name;
}

0 comments on commit 8dcf594

Please sign in to comment.