diff --git a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs index 4ed195579d..d396a76fa4 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs @@ -88,6 +88,7 @@ public override string ConvertName(string name) { } public class EntityConverter { + private const int MAX_DESERIALIZATION_RECURSION_DEPTH = 100; private readonly ConcurrentDictionary _cache; private static readonly JsonSerializerOptions _options = new() { PropertyNamingPolicy = new OnefuzzNamingPolicy(), @@ -124,8 +125,8 @@ public static JsonSerializerOptions GetJsonSerializerOptions() { } private static IEnumerable GetEntityProperties(ParameterInfo parameterInfo) { - var name = parameterInfo.Name.EnsureNotNull($"Invalid paramter {parameterInfo}"); - var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid paramter {parameterInfo}"); + var name = parameterInfo.Name.EnsureNotNull($"Invalid parameter {parameterInfo}"); + var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid parameter {parameterInfo}"); var isRowkey = parameterInfo.GetCustomAttribute(typeof(RowKeyAttribute)) != null; var isPartitionkey = parameterInfo.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null; @@ -135,7 +136,7 @@ private static IEnumerable GetEntityProperties(ParameterInfo (TypeDiscrimnatorAttribute, ITypeProvider)? discriminator = null; if (discriminatorAttribute != null) { - var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrive the type provider")); + var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrieve the type provider")); discriminator = (discriminatorAttribute, t); } @@ -222,7 +223,7 @@ public TableEntity ToTableEntity(T typedEntity) where T : EntityBase { } - private object? GetFieldValue(EntityInfo info, string name, TableEntity entity) { + private object? GetFieldValue(EntityInfo info, string name, TableEntity entity, int iterationCount) { var ef = info.properties[name].First(); if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey) { // partition & row keys must always be strings @@ -285,7 +286,23 @@ public TableEntity ToTableEntity(T typedEntity) where T : EntityBase { var outputType = ef.type; if (ef.discriminator != null) { var (attr, typeProvider) = ef.discriminator.Value; - var v = GetFieldValue(info, attr.FieldName, entity) ?? throw new Exception($"No value for {attr.FieldName}"); + if (iterationCount > MAX_DESERIALIZATION_RECURSION_DEPTH) { + var tags = GenerateTableEntityTags(entity); + tags.AddRange(new (string, string)[] { + ("outputType", outputType?.Name ?? string.Empty), + ("fieldName", fieldName) + }); + throw new OrmMaxRecursionDepthReachedException($"MAX_DESERIALIZATION_RECURSION_DEPTH reached. Too many iterations deserializing {info.type}. {PrintTags(tags)}"); + } + if (attr.FieldName == name) { + var tags = GenerateTableEntityTags(entity); + tags.AddRange(new (string, string)[] { + ("outputType", outputType?.Name ?? string.Empty), + ("fieldName", fieldName) + }); + throw new OrmInvalidDiscriminatorFieldException($"Discriminator field is the same as the field being deserialized {name}. {PrintTags(tags)}"); + } + var v = GetFieldValue(info, attr.FieldName, entity, ++iterationCount) ?? throw new Exception($"No value for {attr.FieldName}"); outputType = typeProvider.GetTypeInfo(v); } @@ -302,8 +319,13 @@ public TableEntity ToTableEntity(T typedEntity) where T : EntityBase { return JsonSerializer.Deserialize(value, outputType, options: _options); } } - } catch (Exception ex) { - throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}')", ex); + } catch (Exception ex) + when (ex is not OrmException) { + var tags = GenerateTableEntityTags(entity); + tags.AddRange(new (string, string)[] { + ("fieldName", fieldName) + }); + throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}'). {PrintTags(tags)}", ex); } } @@ -313,7 +335,7 @@ public T ToRecord(TableEntity entity) where T : EntityBase { object?[] parameters; try { - parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity)).ToArray(); + parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity, 0)).ToArray(); } catch (Exception ex) { throw new InvalidOperationException($"Unable to extract properties from TableEntity for {typeof(T)}", ex); } @@ -361,6 +383,31 @@ public T ToRecord(TableEntity entity) where T : EntityBase { return Expression.Lambda>(call, paramter).Compile(); } + private static List<(string, string)> GenerateTableEntityTags(TableEntity entity) { + var entityKeys = string.Join(',', entity.Keys); + var partitionKey = entity.ContainsKey(EntityPropertyKind.PartitionKey.ToString()) ? entity.GetString(EntityPropertyKind.PartitionKey.ToString()) : string.Empty; + var rowKey = entity.ContainsKey(EntityPropertyKind.RowKey.ToString()) ? entity.GetString(EntityPropertyKind.RowKey.ToString()) : string.Empty; + return new List<(string, string)> { + ("entityKeys", entityKeys), + ("partitionKey", partitionKey), + ("rowKey", rowKey) + }; + } + + private static string PrintTags(List<(string, string)>? tags) { + return tags != null ? string.Join(", ", tags.Select(x => $"{x.Item1}={x.Item2}")) : string.Empty; + } +} + +public class OrmInvalidDiscriminatorFieldException : OrmException { + public OrmInvalidDiscriminatorFieldException(string message) : base(message) { } +} + +public class OrmMaxRecursionDepthReachedException : OrmException { + public OrmMaxRecursionDepthReachedException(string message) : base(message) { } +} +public class OrmException : Exception { + public OrmException(string message) : base(message) { } } diff --git a/src/ApiService/Tests/OrmTest.cs b/src/ApiService/Tests/OrmTest.cs index 5c8f55e712..fb55e00849 100644 --- a/src/ApiService/Tests/OrmTest.cs +++ b/src/ApiService/Tests/OrmTest.cs @@ -1,15 +1,17 @@ using System; +using System.Linq; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; using Azure.Data.Tables; +using FluentAssertions; using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; +using Moq; using Xunit; namespace Tests { public class OrmTest { - sealed class TestObject { public String? TheName { get; set; } public TestEnum TheEnum { get; set; } @@ -410,5 +412,36 @@ public void TestKeyGetters() { Assert.Equal(test.PartitionKey, actualPartitionKey); Assert.Equal(test.RowKey, actualRowKey); } + + sealed record NestedEntity( + [PartitionKey] int Id, + [RowKey] string TheName, + [property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))] + [property: JsonConverter(typeof(BaseEventConverter))] + Nested? EventType + ) : EntityBase(); + +#pragma warning disable CS0169 + public record Nested( + bool? B, + Nested? EventType + ) : BaseEvent(); +#pragma warning restore CS0169 + + [Fact] + public void TestDeeplyNestedObjects() { + var converter = new EntityConverter(); + var deeplyNestedJson = $"{{{string.Concat(Enumerable.Repeat("\"EventType\": {", 3))}{new String('}', 3)}}}"; // {{{...}}} + var nestedEntity = new NestedEntity( + Id: 123, + TheName: "abc", + EventType: JsonSerializer.Deserialize(deeplyNestedJson, new JsonSerializerOptions()) + ); + + var tableEntity = converter.ToTableEntity(nestedEntity); + var toRecord = () => converter.ToRecord(tableEntity); + + _ = toRecord.Should().Throw().And.InnerException!.Should().BeOfType(); + } } }