diff --git a/src/GraphQL.EntityFramework/Where/ArgumentReader.cs b/src/GraphQL.EntityFramework/Where/ArgumentReader.cs index 1b75e792..c6db1df6 100644 --- a/src/GraphQL.EntityFramework/Where/ArgumentReader.cs +++ b/src/GraphQL.EntityFramework/Where/ArgumentReader.cs @@ -9,17 +9,8 @@ public static bool TryReadWhere(IResolveFieldContext context, out IReadOnlyColle public static IReadOnlyCollection ReadOrderBy(IResolveFieldContext context) => ReadList(context, "orderBy"); - public static bool TryReadIds(IResolveFieldContext context, [NotNullWhen(true)] out string[]? idValues) + public static bool TryReadIds(IResolveFieldContext context, [NotNullWhen(true)] out object[]? idValues) { - static string ArgumentToExpression(object argument) => - argument switch - { - long l => l.ToString(CultureInfo.InvariantCulture), - int i => i.ToString(CultureInfo.InvariantCulture), - string s => s, - _ => throw new($"TryReadId got an 'id' argument of type '{argument.GetType().FullName}' which is not supported.") - }; - var arguments = context.Arguments; if (arguments == null) { @@ -43,7 +34,7 @@ static string ArgumentToExpression(object argument) => return false; } - var expressions = new List(); + var expressions = new List(); if (id.Source != ArgumentSource.FieldDefault) { @@ -53,7 +44,7 @@ static string ArgumentToExpression(object argument) => throw new("Null 'id' is not supported."); } - expressions.Add(ArgumentToExpression(idValue)); + expressions.Add(idValue); } if (ids.Source != ArgumentSource.FieldDefault) @@ -63,7 +54,7 @@ static string ArgumentToExpression(object argument) => throw new($"TryReadIds got an 'ids' argument of type '{ids.Value!.GetType().FullName}' which is not supported."); } - expressions.AddRange(objCollection.Select(ArgumentToExpression)); + expressions.AddRange(objCollection); } idValues = expressions.ToArray(); diff --git a/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs b/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs index 20da473a..b87a63f8 100644 --- a/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs +++ b/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs @@ -39,8 +39,18 @@ static Expression MakePredicateBody(IReadOnlyCollection wheres) // Otherwise handle single expressions else { - // Get the predicate body for the single expression - nextExpression = MakePredicateBody(where.Path, where.Comparison, where.Value, where.Negate); + + if (where.Value != null) + { + var property = PropertyCache.GetProperty(where.Path); + var values = TypeConverter.ConvertStringsToList(where.Value, property.PropertyType); + // Get the predicate body for the single expression + nextExpression = MakePredicateBody(where.Path, where.Comparison, values.ToArray(), where.Negate); + } + else + { + nextExpression = MakePredicateBody(where.Path, where.Comparison, null, where.Negate); + } } // If this is the first where processed @@ -65,7 +75,7 @@ static Expression MakePredicateBody(IReadOnlyCollection wheres) /// /// Create a single predicate for the single set of supplied conditional arguments /// - public static Expression> BuildPredicate(string path, Comparison comparison, string?[]? values, bool negate = false) + public static Expression> BuildPredicate(string path, Comparison comparison, object?[]? values, bool negate = false) { var expressionBody = MakePredicateBody(path, comparison, values, negate); var param = PropertyCache.SourceParameter; @@ -73,7 +83,7 @@ public static Expression> BuildPredicate(string path, Comparison c return Expression.Lambda>(expressionBody, param); } - static Expression MakePredicateBody(string path, Comparison comparison, string?[]? values, bool negate) + static Expression MakePredicateBody(string path, Comparison comparison, object?[]? values, bool negate) { try { @@ -106,7 +116,7 @@ static Expression MakePredicateBody(string path, Comparison comparison, string?[ } } - static Expression ProcessList(string path, Comparison comparison, string?[]? values) + static Expression ProcessList(string path, Comparison comparison, object?[]? values) { // Get the path pertaining to individual list items var listPath = ListPropertyRegex().Match(path).Groups[1].Value; @@ -152,7 +162,7 @@ static Expression ProcessList(string path, Comparison comparison, string?[]? val return Expression.Call(anyInfo, property.Left, subPredicate); } - static Expression GetExpression(string path, Comparison comparison, string?[]? values) + static Expression GetExpression(string path, Comparison comparison, object?[]? values) { var property = PropertyCache.GetProperty(path); Expression expressionBody; @@ -193,8 +203,7 @@ static Expression GetExpression(string path, Comparison comparison, string?[]? v default: WhereValidator.ValidateSingleObject(property.PropertyType, comparison); var value = values?.Single(); - var valueObject = TypeConverter.ConvertStringToType(value, property.PropertyType); - expressionBody = MakeSingleObjectComparison(comparison, valueObject, property); + expressionBody = MakeSingleObjectComparison(comparison, value, property); break; } } @@ -202,12 +211,10 @@ static Expression GetExpression(string path, Comparison comparison, string?[]? v return expressionBody; } - static Expression MakeObjectListInComparision(string[] values, Property property) + static Expression MakeObjectListInComparision(object[] values, Property property) { - // Attempt to convert the string values to the object type - var objects = TypeConverter.ConvertStringsToList(values, property.Info); // Make the object values a constant expression - var constant = Expression.Constant(objects); + var constant = Expression.Constant(values); // Build and return the expression body if (property.ListContains is null) { @@ -217,7 +224,7 @@ static Expression MakeObjectListInComparision(string[] values, Property prope return Expression.Call(constant, property.ListContains, property.Left); } - static Expression MakeStringListInComparison(string[] values, Property property) + static Expression MakeStringListInComparison(object[] values, Property property) { var equalsBody = Expression.Call(null, ReflectionCache.StringEqual, ExpressionCache.StringParam, property.Left); @@ -228,7 +235,7 @@ static Expression MakeStringListInComparison(string[] values, Property proper return Expression.Call(null, ReflectionCache.StringAny, Expression.Constant(values), itemEvaluate); } - static Expression MakeSingleStringComparison(Comparison comparison, string? value, Property property) + static Expression MakeSingleStringComparison(Comparison comparison, object? value, Property property) { var left = property.Left; diff --git a/src/GraphQL.EntityFramework/Where/TypeConverter.cs b/src/GraphQL.EntityFramework/Where/TypeConverter.cs index e1b76ac4..2aa38364 100644 --- a/src/GraphQL.EntityFramework/Where/TypeConverter.cs +++ b/src/GraphQL.EntityFramework/Where/TypeConverter.cs @@ -1,6 +1,6 @@ static class TypeConverter { - public static IList ConvertStringsToList(string?[] values, MemberInfo property) + public static List ConvertStringsToList(string?[] values, MemberInfo property) { var hash = new HashSet(); var duplicates = values.Where(_ => !hash.Add(_)).ToArray(); @@ -37,126 +37,126 @@ static bool ParseBoolean(string value) => _ => bool.Parse(value) }; - static IList ConvertStringsToListInternal(IEnumerable values, Type type) + static List ConvertStringsToListInternal(IEnumerable values, Type type) { if (type == typeof(Guid)) { - return values.Select(Guid.Parse).ToList(); + return values.Select(s => (object?)Guid.Parse(s)).ToList(); } if (type == typeof(Guid?)) { - return values.Select(_ => (Guid?)new Guid(_)).ToList(); + return values.Select(_ => (object?)new Guid(_)).ToList(); } if (type == typeof(bool)) { - return values.Select(ParseBoolean).ToList(); + return values.Select(s => (object?)ParseBoolean(s)).ToList(); } if (type == typeof(bool?)) { - return values.Select(_ => (bool?)ParseBoolean(_)).ToList(); + return values.Select(_ => (object?)ParseBoolean(_)).ToList(); } if (type == typeof(int)) { - return values.Select(int.Parse).ToList(); + return values.Select(s => (object?)int.Parse(s)).ToList(); } if (type == typeof(int?)) { - return values.Select(_ => (int?)int.Parse(_)).ToList(); + return values.Select(_ => (object?)int.Parse(_)).ToList(); } if (type == typeof(short)) { - return values.Select(short.Parse).ToList(); + return values.Select(s => (object?)short.Parse(s)).ToList(); } if (type == typeof(short?)) { - return values.Select(_ => (short?)short.Parse(_)).ToList(); + return values.Select(_ => (object?)short.Parse(_)).ToList(); } if (type == typeof(long)) { - return values.Select(long.Parse).ToList(); + return values.Select(s => (object?)long.Parse(s)).ToList(); } if (type == typeof(long?)) { - return values.Select(_ => (long?)long.Parse(_)).ToList(); + return values.Select(_ => (object?)long.Parse(_)).ToList(); } if (type == typeof(uint)) { - return values.Select(uint.Parse).ToList(); + return values.Select(s => (object?)uint.Parse(s)).ToList(); } if (type == typeof(uint?)) { - return values.Select(_ => (uint?)uint.Parse(_)).ToList(); + return values.Select(_ => (object?)uint.Parse(_)).ToList(); } if (type == typeof(ushort)) { - return values.Select(ushort.Parse).ToList(); + return values.Select(s =>(object?) ushort.Parse(s)).ToList(); } if (type == typeof(ushort?)) { - return values.Select(_ => (ushort?)ushort.Parse(_)).ToList(); + return values.Select(_ => (object?)ushort.Parse(_)).ToList(); } if (type == typeof(ulong)) { - return values.Select(ulong.Parse).ToList(); + return values.Select(s => (object?)ulong.Parse(s)).ToList(); } if (type == typeof(ulong?)) { - return values.Select(_ => (ulong?)ulong.Parse(_)).ToList(); + return values.Select(_ => (object?)ulong.Parse(_)).ToList(); } if (type == typeof(DateTime)) { - return values.Select(DateTime.Parse).ToList(); + return values.Select(s => (object?)DateTime.Parse(s)).ToList(); } if (type == typeof(DateTime?)) { - return values.Select(_ => (DateTime?)DateTime.Parse(_)).ToList(); + return values.Select(_ => (object?)DateTime.Parse(_)).ToList(); } if (type == typeof(Time)) { - return values.Select(Time.Parse).ToList(); + return values.Select(s => (object?)Time.Parse(s)).ToList(); } if (type == typeof(Time?)) { - return values.Select(_ => (Time?)Time.Parse(_)).ToList(); + return values.Select(_ => (object?)Time.Parse(_)).ToList(); } if (type == typeof(Date)) { - return values.Select(_ => Date.ParseExact(_, "yyyy-MM-dd")).ToList(); + return values.Select(_ => (object?)Date.ParseExact(_, "yyyy-MM-dd")).ToList(); } if (type == typeof(Date?)) { - return values.Select(_ => (Date?)Date.ParseExact(_, "yyyy-MM-dd")).ToList(); + return values.Select(_ => (object?)Date.ParseExact(_, "yyyy-MM-dd")).ToList(); } if (type == typeof(DateTimeOffset)) { - return values.Select(DateTimeOffset.Parse).ToList(); + return values.Select(s => (object?)DateTimeOffset.Parse(s)).ToList(); } if (type == typeof(DateTimeOffset?)) { - return values.Select(_ => (DateTimeOffset?)DateTimeOffset.Parse(_)).ToList(); + return values.Select(_ => (object?)DateTimeOffset.Parse(_)).ToList(); } if (type.IsEnum) { var getList = enumListMethod.MakeGenericMethod(type); - return (IList)getList.Invoke(null, [values])!; + return (List)getList.Invoke(null, [values])!; } if (type.TryGetEnumType(out var enumType)) { var getList = nullableEnumListMethod.MakeGenericMethod(enumType); - return (IList)getList.Invoke(null, [values])!; + return (List)getList.Invoke(null, [values])!; } throw new($"Could not convert strings to {type.FullName}.");