Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Cap recursion in ORM #2992

Merged
merged 16 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public void Flush() {
}

//TODO: Should we write errors and Exception to std err ?
sealed class Console : ILog {
public sealed class Console : ILog {

private static string DictToString<T>(IReadOnlyDictionary<string, T>? d) {
if (d is null) {
Expand Down
54 changes: 47 additions & 7 deletions src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public override string ConvertName(string name) {
}
public class EntityConverter {

private const int MAX_DESERIALIZATION_RECURSION_DEPTH = 100;
private readonly ILogTracer _logTracer;
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
private static readonly JsonSerializerOptions _options = new() {
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
Expand All @@ -97,8 +99,9 @@ public class EntityConverter {
}
};

public EntityConverter() {
public EntityConverter(ILogTracer logTracer) {
_cache = new ConcurrentDictionary<Type, EntityInfo>();
_logTracer = logTracer;
}

public static JsonSerializerOptions GetJsonSerializerOptions() {
Expand All @@ -124,8 +127,8 @@ public static JsonSerializerOptions GetJsonSerializerOptions() {
}

private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo parameterInfo) {
var name = parameterInfo.Name.EnsureNotNull($"Invalid paramter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid paramter {parameterInfo}");
var name = parameterInfo.Name.EnsureNotNull($"Invalid paramater {parameterInfo}");
tevoinea marked this conversation as resolved.
Show resolved Hide resolved
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid paramater {parameterInfo}");
tevoinea marked this conversation as resolved.
Show resolved Hide resolved
var isRowkey = parameterInfo.GetCustomAttribute(typeof(RowKeyAttribute)) != null;
var isPartitionkey = parameterInfo.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null;

Expand All @@ -135,7 +138,7 @@ private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo

(TypeDiscrimnatorAttribute, ITypeProvider)? discriminator = null;
if (discriminatorAttribute != null) {
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrive the type provider"));
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrieve the type provider"));
discriminator = (discriminatorAttribute, t);
}

Expand Down Expand Up @@ -222,7 +225,7 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
}


private object? GetFieldValue(EntityInfo info, string name, TableEntity entity) {
private object? GetFieldValue(EntityInfo info, string name, TableEntity entity, int iterationCount = 0) {
tevoinea marked this conversation as resolved.
Show resolved Hide resolved
var ef = info.properties[name].First();
if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey) {
// partition & row keys must always be strings
Expand Down Expand Up @@ -285,7 +288,25 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
var outputType = ef.type;
if (ef.discriminator != null) {
var (attr, typeProvider) = ef.discriminator.Value;
var v = GetFieldValue(info, attr.FieldName, entity) ?? throw new Exception($"No value for {attr.FieldName}");
if (iterationCount > MAX_DESERIALIZATION_RECURSION_DEPTH) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
_logTracer.WithTags(tags).Error($"Too many iterations deserializing {info.type}");
tevoinea marked this conversation as resolved.
Show resolved Hide resolved
throw new OrmShortCircuitInfiniteLoopException("MAX_DESERIALIZATION_RECURSION_DEPTH reached");
}
if (attr.FieldName == name) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
_logTracer.WithTags(tags).Error($"Discriminator field is the same as the field being deserialized {name}");
throw new OrmShortCircuitInfiniteLoopException("Discriminator field cannot be the same as the field being deserialized");
}
var v = GetFieldValue(info, attr.FieldName, entity, ++iterationCount) ?? throw new Exception($"No value for {attr.FieldName}");
outputType = typeProvider.GetTypeInfo(v);
}

Expand All @@ -302,7 +323,13 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
return JsonSerializer.Deserialize(value, outputType, options: _options);
}
}
} catch (Exception ex) {
} catch (Exception ex)
when (ex is not OrmShortCircuitInfiniteLoopException) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("fieldName", fieldName)
});
_logTracer.WithTags(tags).Error($"Unable to get value for property '{name}' (entity field '{fieldName}')");
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}')", ex);
}
}
Expand Down Expand Up @@ -361,6 +388,19 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {
return Expression.Lambda<Func<T, object?>>(call, paramter).Compile();
}

private static List<(string, string)> GenerateTableEntityTags(TableEntity entity) {
var entityKeys = string.Join(',', entity.Keys);
var partitionKey = entity.ContainsKey(EntityPropertyKind.PartitionKey.ToString()) ? entity.GetString(EntityPropertyKind.PartitionKey.ToString()) : string.Empty;
var rowKey = entity.ContainsKey(EntityPropertyKind.RowKey.ToString()) ? entity.GetString(EntityPropertyKind.RowKey.ToString()) : string.Empty;

return new List<(string, string)> {
("entityKeys", entityKeys),
("partitionKey", partitionKey),
("rowKey", rowKey)
};
}
}

public class OrmShortCircuitInfiniteLoopException : Exception {
public OrmShortCircuitInfiniteLoopException(string message) : base(message) { }
}
2 changes: 1 addition & 1 deletion src/ApiService/IntegrationTests/Fakes/TestContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace IntegrationTests.Fakes;
public sealed class TestContext : IOnefuzzContext {
public TestContext(ILogTracer logTracer, IStorage storage, ICreds creds, string storagePrefix) {
var cache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
EntityConverter = new EntityConverter();
EntityConverter = new EntityConverter(logTracer);
ServiceConfiguration = new TestServiceConfiguration(storagePrefix);
Storage = storage;
Creds = creds;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.OneFuzz.Service;

using Async = System.Threading.Tasks;
using Console = System.Console;

namespace IntegrationTests.Integration;

Expand Down
3 changes: 2 additions & 1 deletion src/ApiService/Tests/OrmModelsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using FsCheck.Xunit;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Moq;
using Xunit.Abstractions;

namespace Tests {
Expand Down Expand Up @@ -693,7 +694,7 @@ public static bool AreEqual<T>(T r1, T r2) {
}

public class OrmModelsTest {
EntityConverter _converter = new EntityConverter();
EntityConverter _converter = new EntityConverter(new Mock<ILogTracer>().Object);
ITestOutputHelper _output;

public OrmModelsTest(ITestOutputHelper output) {
Expand Down
58 changes: 49 additions & 9 deletions src/ApiService/Tests/OrmTest.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Azure.Data.Tables;
using FluentAssertions;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Moq;
using Xunit;

namespace Tests {
public class OrmTest {
private ILogTracer _logTracer;

public OrmTest() {
_logTracer = new LogTracerFactory(new List<ILog> { new Microsoft.OneFuzz.Service.Console() }).CreateLogTracer(Guid.Empty);
}

sealed class TestObject {
public String? TheName { get; set; }
Expand Down Expand Up @@ -55,7 +64,7 @@ sealed record Entity1(
[Fact]
public void TestBothDirections() {
var uriString = new Uri("https://localhost:9090");
var converter = new EntityConverter();
var converter = new EntityConverter(_logTracer);
var entity1 = new Entity1(
Guid.NewGuid(),
"test",
Expand Down Expand Up @@ -105,7 +114,7 @@ public void TestBothDirections() {
[Fact]
public void TestConvertToTableEntity() {
var uriString = new Uri("https://localhost:9090");
var converter = new EntityConverter();
var converter = new EntityConverter(_logTracer);
var entity1 = new Entity1(
Guid.NewGuid(),
"test",
Expand Down Expand Up @@ -154,7 +163,7 @@ public void TestConvertToTableEntity() {

[Fact]
public void TestFromtableEntity() {
var converter = new EntityConverter();
var converter = new EntityConverter(_logTracer);
var tableEntity = new TableEntity(Guid.NewGuid().ToString(), "test") {
{"the_date", DateTimeOffset.UtcNow },
{ "the_number", 1234},
Expand Down Expand Up @@ -249,7 +258,7 @@ [RowKey] string TheName
[Fact]
public void TestIntKey() {
var expected = new Entity2(10, "test");
var converter = new EntityConverter();
var converter = new EntityConverter(_logTracer);
var tableEntity = converter.ToTableEntity(expected);
var actual = converter.ToRecord<Entity2>(tableEntity);

Expand All @@ -267,7 +276,7 @@ Container Container
public void TestContainerSerialization() {
var container = Container.Parse("abc-123");
var expected = new Entity3(123, "abc", container);
var converter = new EntityConverter();
var converter = new EntityConverter(_logTracer);

var tableEntity = converter.ToTableEntity(expected);
var actual = converter.ToRecord<Entity3>(tableEntity);
Expand Down Expand Up @@ -302,7 +311,7 @@ Container Container
public void TestPartitionKeyIsRowKey() {
var container = Container.Parse("abc-123");
var expected = new Entity4(123, "abc", container);
var converter = new EntityConverter();
var converter = new EntityConverter(_logTracer);

var tableEntity = converter.ToTableEntity(expected);
Assert.Equal(expected.Id.ToString(), tableEntity.RowKey);
Expand Down Expand Up @@ -336,7 +345,7 @@ sealed record TestNullField(int? Id, string? Name, TestObject? Obj) : EntityBase
[Fact]
public void TestNullValue() {

var entityConverter = new EntityConverter();
var entityConverter = new EntityConverter(_logTracer);
var tableEntity = entityConverter.ToTableEntity(new TestNullField(null, null, null));

Assert.Null(tableEntity["id"]);
Expand Down Expand Up @@ -367,7 +376,7 @@ sealed record TestEntity3(DoNotRename Enum, DoNotRenameFlag flag) : EntityBase()
[Fact]
public void TestSkipRename() {

var entityConverter = new EntityConverter();
var entityConverter = new EntityConverter(_logTracer);

var expected = new TestEntity3(DoNotRename.TEST3, DoNotRenameFlag.Test_2 | DoNotRenameFlag.test1);
var tableEntity = entityConverter.ToTableEntity(expected);
Expand All @@ -390,7 +399,7 @@ sealed record TestIinit([DefaultValue(InitMethod.DefaultConstructor)] TestClass

[Fact]
public void TestInitValue() {
var entityConverter = new EntityConverter();
var entityConverter = new EntityConverter(_logTracer);
var tableEntity = new TableEntity();
var actual = entityConverter.ToRecord<TestIinit>(tableEntity);

Expand All @@ -410,5 +419,36 @@ public void TestKeyGetters() {
Assert.Equal(test.PartitionKey, actualPartitionKey);
Assert.Equal(test.RowKey, actualRowKey);
}

sealed record NestedEntity(
[PartitionKey] int Id,
[RowKey] string TheName,
[property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))]
[property: JsonConverter(typeof(BaseEventConverter))]
Nested? EventType
) : EntityBase();

#pragma warning disable CS0169
public record Nested(
bool? B,
Nested? EventType
) : BaseEvent();
#pragma warning restore CS0169

[Fact]
public void TestDeeplyNestedObjects() {
var converter = new EntityConverter(_logTracer);
var deeplyNestedJson = $"{{{string.Concat(Enumerable.Repeat("\"EventType\": {", 3))}{new String('}', 3)}}}"; // {{{...}}}
var nestedEntity = new NestedEntity(
Id: 123,
TheName: "abc",
EventType: JsonSerializer.Deserialize<Nested>(deeplyNestedJson, new JsonSerializerOptions())
);

var tableEntity = converter.ToTableEntity(nestedEntity);
var toRecord = () => converter.ToRecord<NestedEntity>(tableEntity);

_ = toRecord.Should().Throw<Exception>().And.InnerException!.Should().BeOfType<OrmShortCircuitInfiniteLoopException>();
}
}
}