Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Commit

Permalink
Cap recursion in ORM (#2992)
Browse files Browse the repository at this point in the history
* Add new command

* Update remaining jinja templates and references to use scriban

* almost done

* making progress

* Add 2 cases to stop OOM exceptions in the future

* More logs

* PR feedback

* Remove unnecessary changes

* 🧹

* PR comments
  • Loading branch information
tevoinea authored Apr 12, 2023
1 parent ace0ccc commit 41fa0a7
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 9 deletions.
63 changes: 55 additions & 8 deletions src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public override string ConvertName(string name) {
}
public class EntityConverter {

private const int MAX_DESERIALIZATION_RECURSION_DEPTH = 100;
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
private static readonly JsonSerializerOptions _options = new() {
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
Expand Down Expand Up @@ -124,8 +125,8 @@ public static JsonSerializerOptions GetJsonSerializerOptions() {
}

private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo parameterInfo) {
var name = parameterInfo.Name.EnsureNotNull($"Invalid paramter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid paramter {parameterInfo}");
var name = parameterInfo.Name.EnsureNotNull($"Invalid parameter {parameterInfo}");
var parameterType = parameterInfo.ParameterType.EnsureNotNull($"Invalid parameter {parameterInfo}");
var isRowkey = parameterInfo.GetCustomAttribute(typeof(RowKeyAttribute)) != null;
var isPartitionkey = parameterInfo.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null;

Expand All @@ -135,7 +136,7 @@ private static IEnumerable<EntityProperty> GetEntityProperties<T>(ParameterInfo

(TypeDiscrimnatorAttribute, ITypeProvider)? discriminator = null;
if (discriminatorAttribute != null) {
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrive the type provider"));
var t = (ITypeProvider)(Activator.CreateInstance(discriminatorAttribute.ConverterType) ?? throw new Exception("unable to retrieve the type provider"));
discriminator = (discriminatorAttribute, t);
}

Expand Down Expand Up @@ -222,7 +223,7 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
}


private object? GetFieldValue(EntityInfo info, string name, TableEntity entity) {
private object? GetFieldValue(EntityInfo info, string name, TableEntity entity, int iterationCount) {
var ef = info.properties[name].First();
if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey) {
// partition & row keys must always be strings
Expand Down Expand Up @@ -285,7 +286,23 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
var outputType = ef.type;
if (ef.discriminator != null) {
var (attr, typeProvider) = ef.discriminator.Value;
var v = GetFieldValue(info, attr.FieldName, entity) ?? throw new Exception($"No value for {attr.FieldName}");
if (iterationCount > MAX_DESERIALIZATION_RECURSION_DEPTH) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
throw new OrmMaxRecursionDepthReachedException($"MAX_DESERIALIZATION_RECURSION_DEPTH reached. Too many iterations deserializing {info.type}. {PrintTags(tags)}");
}
if (attr.FieldName == name) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("outputType", outputType?.Name ?? string.Empty),
("fieldName", fieldName)
});
throw new OrmInvalidDiscriminatorFieldException($"Discriminator field is the same as the field being deserialized {name}. {PrintTags(tags)}");
}
var v = GetFieldValue(info, attr.FieldName, entity, ++iterationCount) ?? throw new Exception($"No value for {attr.FieldName}");
outputType = typeProvider.GetTypeInfo(v);
}

Expand All @@ -302,8 +319,13 @@ public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
return JsonSerializer.Deserialize(value, outputType, options: _options);
}
}
} catch (Exception ex) {
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}')", ex);
} catch (Exception ex)
when (ex is not OrmException) {
var tags = GenerateTableEntityTags(entity);
tags.AddRange(new (string, string)[] {
("fieldName", fieldName)
});
throw new InvalidOperationException($"Unable to get value for property '{name}' (entity field '{fieldName}'). {PrintTags(tags)}", ex);
}
}

Expand All @@ -313,7 +335,7 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {

object?[] parameters;
try {
parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity)).ToArray();
parameters = entityInfo.properties.Select(grouping => GetFieldValue(entityInfo, grouping.Key, entity, 0)).ToArray();
} catch (Exception ex) {
throw new InvalidOperationException($"Unable to extract properties from TableEntity for {typeof(T)}", ex);
}
Expand Down Expand Up @@ -361,6 +383,31 @@ public T ToRecord<T>(TableEntity entity) where T : EntityBase {
return Expression.Lambda<Func<T, object?>>(call, paramter).Compile();
}

private static List<(string, string)> GenerateTableEntityTags(TableEntity entity) {
var entityKeys = string.Join(',', entity.Keys);
var partitionKey = entity.ContainsKey(EntityPropertyKind.PartitionKey.ToString()) ? entity.GetString(EntityPropertyKind.PartitionKey.ToString()) : string.Empty;
var rowKey = entity.ContainsKey(EntityPropertyKind.RowKey.ToString()) ? entity.GetString(EntityPropertyKind.RowKey.ToString()) : string.Empty;

return new List<(string, string)> {
("entityKeys", entityKeys),
("partitionKey", partitionKey),
("rowKey", rowKey)
};
}

private static string PrintTags(List<(string, string)>? tags) {
return tags != null ? string.Join(", ", tags.Select(x => $"{x.Item1}={x.Item2}")) : string.Empty;
}
}

public class OrmInvalidDiscriminatorFieldException : OrmException {
public OrmInvalidDiscriminatorFieldException(string message) : base(message) { }
}

public class OrmMaxRecursionDepthReachedException : OrmException {
public OrmMaxRecursionDepthReachedException(string message) : base(message) { }
}

public class OrmException : Exception {
public OrmException(string message) : base(message) { }
}
35 changes: 34 additions & 1 deletion src/ApiService/Tests/OrmTest.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
using System;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Azure.Data.Tables;
using FluentAssertions;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Moq;
using Xunit;

namespace Tests {
public class OrmTest {

sealed class TestObject {
public String? TheName { get; set; }
public TestEnum TheEnum { get; set; }
Expand Down Expand Up @@ -410,5 +412,36 @@ public void TestKeyGetters() {
Assert.Equal(test.PartitionKey, actualPartitionKey);
Assert.Equal(test.RowKey, actualRowKey);
}

sealed record NestedEntity(
[PartitionKey] int Id,
[RowKey] string TheName,
[property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))]
[property: JsonConverter(typeof(BaseEventConverter))]
Nested? EventType
) : EntityBase();

#pragma warning disable CS0169
public record Nested(
bool? B,
Nested? EventType
) : BaseEvent();
#pragma warning restore CS0169

[Fact]
public void TestDeeplyNestedObjects() {
var converter = new EntityConverter();
var deeplyNestedJson = $"{{{string.Concat(Enumerable.Repeat("\"EventType\": {", 3))}{new String('}', 3)}}}"; // {{{...}}}
var nestedEntity = new NestedEntity(
Id: 123,
TheName: "abc",
EventType: JsonSerializer.Deserialize<Nested>(deeplyNestedJson, new JsonSerializerOptions())
);

var tableEntity = converter.ToTableEntity(nestedEntity);
var toRecord = () => converter.ToRecord<NestedEntity>(tableEntity);

_ = toRecord.Should().Throw<Exception>().And.InnerException!.Should().BeOfType<OrmInvalidDiscriminatorFieldException>();
}
}
}

0 comments on commit 41fa0a7

Please sign in to comment.