Skip to content

Commit

Permalink
Allow array types
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers committed Aug 11, 2024
1 parent 2b43dbe commit 8efd22d
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 47 deletions.
13 changes: 4 additions & 9 deletions src/EFCore.Cosmos/Extensions/CosmosPropertyBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,8 @@ public static PropertyBuilder<TProperty> IsETagConcurrency<TProperty>(

[Experimental(EFDiagnostics.CosmosVectorSearchExperimental)]
private static CosmosVectorType CreateVectorType(DistanceFunction distanceFunction, int dimensions)
{
if (!Enum.IsDefined(distanceFunction))
{
throw new ArgumentException(CoreStrings.InvalidEnumValue(distanceFunction, nameof(distanceFunction), typeof(DistanceFunction)));
}

var vectorType = new CosmosVectorType(distanceFunction, dimensions);
return vectorType;
}
=> Enum.IsDefined(distanceFunction)
? new CosmosVectorType(distanceFunction, dimensions)
: throw new ArgumentException(
CoreStrings.InvalidEnumValue(distanceFunction, nameof(distanceFunction), typeof(DistanceFunction)));
}
4 changes: 3 additions & 1 deletion src/EFCore.Cosmos/Metadata/Internal/CosmosVectorType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ public sealed record class CosmosVectorType(DistanceFunction DistanceFunction, i
/// </summary>
public static VectorDataType CreateDefaultVectorDataType(Type clrType)
{
var elementType = clrType.TryGetElementType(typeof(ReadOnlyMemory<>))?.UnwrapNullableType();
var elementType = clrType.TryGetElementType(typeof(ReadOnlyMemory<>))?.UnwrapNullableType()
?? clrType.TryGetElementType(typeof(IEnumerable<>))?.UnwrapNullableType();

return elementType == typeof(sbyte)
? VectorDataType.Int8
: elementType == typeof(byte)
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/EFCore.Cosmos/Properties/CosmosStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
<value>The type '{givenType}' cannot be mapped as a dictionary because it does not implement '{dictionaryType}'.</value>
</data>
<data name="BadVectorDataType" xml:space="preserve">
<value>The type '{clrType}' is being used as a vector, but the vector data type cannot be inferred. Only 'ReadOnlyMemory&lt;byte&gt;, ReadOnlyMemory&lt;sbyte&gt;, and ReadOnlyMemory&lt;float&gt; are supported.</value>
<value>The type '{clrType}' is being used as a vector, but the vector data type cannot be inferred. Only 'ReadOnlyMemory&lt;byte&gt;, ReadOnlyMemory&lt;sbyte&gt;, ReadOnlyMemory&lt;float&gt;, byte[], sbyte[], and float[] are supported.</value>
</data>
<data name="CanConnectNotSupported" xml:space="preserve">
<value>The Cosmos database does not support 'CanConnect' or 'CanConnectAsync'.</value>
Expand Down
14 changes: 5 additions & 9 deletions src/EFCore.Cosmos/Storage/Internal/CosmosTypeMappingSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,15 @@ public CosmosTypeMappingSource(TypeMappingSourceDependencies dependencies)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override CoreTypeMapping? FindMapping(IProperty property)
{
// A provider should typically not override this because using the property directly causes problems with Migrations where
// the property does not exist. However, since the Cosmos provider doesn't have Migrations, it should be okay to use the property
// directly.
var mapping = (CosmosTypeMapping?)base.FindMapping(property);
if (mapping is not null
&& property.FindAnnotation(CosmosAnnotationNames.VectorType)?.Value is CosmosVectorType vectorType)
=> base.FindMapping(property) switch
{
mapping = new CosmosVectorTypeMapping(mapping, vectorType);
}

return mapping;
}
CosmosTypeMapping mapping when property.FindAnnotation(CosmosAnnotationNames.VectorType)?.Value is CosmosVectorType vectorType
=> new CosmosVectorTypeMapping(mapping, vectorType),
var other => other
};

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
135 changes: 109 additions & 26 deletions test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,47 @@ FROM root c
""");
}

[ConditionalFact]
public virtual async Task Query_for_vector_distance_bytes_array()
{
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 3, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["BytesArray"]
FROM root c
""");
}

[ConditionalFact]
public virtual async Task Query_for_vector_distance_singles_array()
{
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)).ToListAsync());

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["SinglesArray"]
FROM root c
""");
}

[ConditionalFact]
public virtual async Task Vector_distance_sbytes_in_OrderBy()
{
Expand Down Expand Up @@ -128,8 +169,8 @@ public virtual async Task Vector_distance_bytes_in_OrderBy()
.ToListAsync();

Assert.Equal(3, booksFromStore.Count);
AssertSql(
"""
AssertSql(
"""
@__p_1='[2,1,4,6,5,2,5,7,3,1]'

SELECT VALUE c
Expand Down Expand Up @@ -160,6 +201,37 @@ ORDER BY VectorDistance(c["Singles"], @__p_1, false, {'distanceFunction':'cosine
""");
}

[ConditionalFact]
public virtual async Task Vector_distance_bytes_array_in_OrderBy()
{
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 6, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());

// Assert.Equal(3, booksFromStore.Count);

AssertSql(
);
}

[ConditionalFact]
public virtual async Task Vector_distance_singles_array_in_OrderBy()
{
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)).ToListAsync());

// Assert.Equal(3, booksFromStore.Count);

AssertSql();
}

[ConditionalFact]
public virtual async Task VectorDistance_throws_when_used_on_non_vector()
{
Expand Down Expand Up @@ -239,6 +311,10 @@ private class Book

public ReadOnlyMemory<float> Singles { get; set; } = null!;

public byte[] BytesArray { get; set; } = null!;

public float[] SinglesArray { get; set; } = null!;

public Owned1 OwnedReference { get; set; } = null!;
public List<Owned1> OwnedCollection { get; set; } = null!;
}
Expand Down Expand Up @@ -279,10 +355,14 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.HasIndex(e => e.Bytes).ForVectors(VectorIndexType.Flat);
b.HasIndex(e => e.SBytes).ForVectors(VectorIndexType.Flat);
b.HasIndex(e => e.Singles).ForVectors(VectorIndexType.Flat);
b.HasIndex(e => e.BytesArray).ForVectors(VectorIndexType.Flat);
b.HasIndex(e => e.SinglesArray).ForVectors(VectorIndexType.Flat);
b.Property(e => e.Bytes).IsVector(DistanceFunction.Cosine, 10);
b.Property(e => e.SBytes).IsVector(DistanceFunction.DotProduct, 10);
b.Property(e => e.Singles).IsVector(DistanceFunction.Cosine, 10);
b.Property(e => e.BytesArray).IsVector(DistanceFunction.Cosine, 10);
b.Property(e => e.SinglesArray).IsVector(DistanceFunction.Cosine, 10);
});

protected override Task SeedAsync(PoolableDbContext context)
Expand All @@ -293,17 +373,18 @@ protected override Task SeedAsync(PoolableDbContext context)
Author = "Jon P Smith",
Title = "Entity Framework Core in Action",
Isbn = new ReadOnlyMemory<byte>("978-1617298363"u8.ToArray()),
Bytes = new([2, 1, 4, 3, 5, 2, 5, 7, 3, 1]),
SBytes = new([2, -1, 4, 3, 5, -2, 5, -7, 3, 1]),
Singles = new([ 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f ]),

OwnedReference = new()
Bytes = new ReadOnlyMemory<byte>([2, 1, 4, 3, 5, 2, 5, 7, 3, 1]),
SBytes = new ReadOnlyMemory<sbyte>([2, -1, 4, 3, 5, -2, 5, -7, 3, 1]),
Singles = new ReadOnlyMemory<float>([0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f]),
BytesArray = [2, 1, 4, 3, 5, 2, 5, 7, 3, 1],
SinglesArray = [0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f],
OwnedReference = new Owned1
{
Prop = 7,
NestedOwned = new() { Prop = "7" },
NestedOwnedCollection = new() { new() { Prop = "71" }, new() { Prop = "72" } }
NestedOwned = new Owned2 { Prop = "7" },
NestedOwnedCollection = new List<Owned2> { new() { Prop = "71" }, new() { Prop = "72" } }
},
OwnedCollection = new() { new() { Prop = 71 }, new() { Prop = 72 } }
OwnedCollection = new List<Owned1> { new Owned1 { Prop = 71 }, new Owned1 { Prop = 72 } }
};

var book2 = new Book
Expand All @@ -312,17 +393,18 @@ protected override Task SeedAsync(PoolableDbContext context)
Author = "Julie Lerman",
Title = "Programming Entity Framework: DbContext",
Isbn = new ReadOnlyMemory<byte>("978-1449312961"u8.ToArray()),
Bytes = new([2, 1, 4, 3, 5, 2, 5, 7, 3, 1]),
SBytes = new([2, -1, 4, 3, 5, -2, 5, -7, 3, 1]),
Singles = new([ 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f ]),

OwnedReference = new()
Bytes = new ReadOnlyMemory<byte>([2, 1, 4, 3, 5, 2, 5, 7, 3, 1]),
SBytes = new ReadOnlyMemory<sbyte>([2, -1, 4, 3, 5, -2, 5, -7, 3, 1]),
Singles = new ReadOnlyMemory<float>([0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f]),
BytesArray = [2, 1, 4, 3, 5, 2, 5, 7, 3, 1],
SinglesArray = [0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f],
OwnedReference = new Owned1
{
Prop = 7,
NestedOwned = new() { Prop = "7" },
NestedOwnedCollection = new() { new() { Prop = "71" }, new() { Prop = "72" } }
NestedOwned = new Owned2 { Prop = "7" },
NestedOwnedCollection = new List<Owned2> { new() { Prop = "71" }, new() { Prop = "72" } }
},
OwnedCollection = new() { new() { Prop = 71 }, new() { Prop = 72 } }
OwnedCollection = new List<Owned1> { new Owned1 { Prop = 71 }, new Owned1 { Prop = 72 } }
};

var book3 = new Book
Expand All @@ -331,17 +413,18 @@ protected override Task SeedAsync(PoolableDbContext context)
Author = "Julie Lerman",
Title = "Programming Entity Framework",
Isbn = new ReadOnlyMemory<byte>("978-0596807269"u8.ToArray()),
Bytes = new([2, 1, 4, 3, 5, 2, 5, 7, 3, 1]),
SBytes = new([2, -1, 4, 3, 5, -2, 5, -7, 3, 1]),
Singles = new([ 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f ]),

OwnedReference = new()
Bytes = new ReadOnlyMemory<byte>([2, 1, 4, 3, 5, 2, 5, 7, 3, 1]),
SBytes = new ReadOnlyMemory<sbyte>([2, -1, 4, 3, 5, -2, 5, -7, 3, 1]),
Singles = new ReadOnlyMemory<float>([0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f]),
BytesArray = [2, 1, 4, 3, 5, 2, 5, 7, 3, 1],
SinglesArray = [0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f],
OwnedReference = new Owned1
{
Prop = 7,
NestedOwned = new() { Prop = "7" },
NestedOwnedCollection = new() { new() { Prop = "71" }, new() { Prop = "72" } }
NestedOwned = new Owned2 { Prop = "7" },
NestedOwnedCollection = new List<Owned2> { new() { Prop = "71" }, new() { Prop = "72" } }
},
OwnedCollection = new() { new() { Prop = 71 }, new() { Prop = 72 } }
OwnedCollection = new List<Owned1> { new Owned1 { Prop = 71 }, new Owned1 { Prop = 72 } }
};

context.AddRange(book1, book2, book3);
Expand Down

0 comments on commit 8efd22d

Please sign in to comment.