Skip to content

Commit

Permalink
Adds custom key support in merge operation
Browse files Browse the repository at this point in the history
  • Loading branch information
redbaty committed Nov 6, 2023
1 parent ef3d400 commit 9093148
Show file tree
Hide file tree
Showing 15 changed files with 210 additions and 33 deletions.
4 changes: 4 additions & 0 deletions PgBulk.Abstractions/ITableColumnInformation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

public interface ITableColumnInformation
{
int Index { get; }

string Name { get; }

string SafeName => Name.StartsWith('"') && Name.EndsWith('"') ? Name : $"\"{Name}\"";

bool PrimaryKey { get; }

bool ValueGeneratedOnAdd { get; }
Expand Down
6 changes: 6 additions & 0 deletions PgBulk.Abstractions/ITableKeyProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace PgBulk.Abstractions;

public interface ITableKeyProvider
{
TableKey GetKeyColumns(ITableInformation tableInformation);
}
2 changes: 1 addition & 1 deletion PgBulk.Abstractions/PgBulk.Abstractions.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<Nullable>enable</Nullable>
<LangVersion>10</LangVersion>
<PackageProjectUrl>https://github.com/redbaty/PgBulk</PackageProjectUrl>
<PackageVersion>1.1.2</PackageVersion>
<PackageVersion>1.1.3</PackageVersion>
</PropertyGroup>

</Project>
3 changes: 3 additions & 0 deletions PgBulk.Abstractions/TableKey.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace PgBulk.Abstractions;

public record TableKey(ICollection<ITableColumnInformation> Columns, bool IsUniqueConstraint);
11 changes: 9 additions & 2 deletions PgBulk.EFCore/ContextExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.EntityFrameworkCore;
using PgBulk.Abstractions;

namespace PgBulk.EFCore;

Expand All @@ -10,10 +11,16 @@ public static Task BulkSyncAsync<T>(this DbContext dbContext, IEnumerable<T> ent
return @operator.SyncAsync(entities, deleteWhere);
}

public static Task BulkMergeAsync<T>(this DbContext dbContext, IEnumerable<T> entities, int? timeoutOverride = 600, bool useContextConnection = true) where T : class
public static Task BulkMergeAsync<T>(this DbContext dbContext, IEnumerable<T> entities, int? timeoutOverride = 600, bool useContextConnection = true, ITableKeyProvider? tableKeyProvider = null) where T : class
{
var @operator = new BulkEfOperator(dbContext, timeoutOverride, useContextConnection);
return @operator.MergeAsync(entities);
return @operator.MergeAsync(entities.ToList(), tableKeyProvider);
}

public static Task BulkMergeAsync<T>(this DbContext dbContext, ICollection<T> entities, int? timeoutOverride = 600, bool useContextConnection = true, ITableKeyProvider? tableKeyProvider = null) where T : class
{
var @operator = new BulkEfOperator(dbContext, timeoutOverride, useContextConnection);
return @operator.MergeAsync(entities, tableKeyProvider);
}

public static Task BulkInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, int? timeoutOverride = 600, bool useContextConnection = true) where T : class
Expand Down
7 changes: 5 additions & 2 deletions PgBulk.EFCore/EntityColumnInformation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ namespace PgBulk.EFCore;

public class EntityColumnInformation : ITableColumnInformation
{
public EntityColumnInformation(string name, bool primaryKey, bool valueGeneratedOnAdd, PropertyInfo? property)
public EntityColumnInformation(string name, bool primaryKey, bool valueGeneratedOnAdd, PropertyInfo? property, int index)
{
Name = name;
PrimaryKey = primaryKey;
Property = property;
Index = index;
ValueGeneratedOnAdd = valueGeneratedOnAdd;
}

private PropertyInfo? Property { get; }
public PropertyInfo? Property { get; }

public int Index { get; }

public string Name { get; }

Expand Down
46 changes: 46 additions & 0 deletions PgBulk.EFCore/EntityManualTableKeyProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using PgBulk.Abstractions;

namespace PgBulk.EFCore;

public class EntityManualTableKeyProvider<TEntity> : ITableKeyProvider
{
private readonly ICollection<ITableColumnInformation> _primaryKeyColumns;

public EntityManualTableKeyProvider(ICollection<ITableColumnInformation> primaryKeyColumns)
{
_primaryKeyColumns = primaryKeyColumns;
}

public EntityManualTableKeyProvider()
{
_primaryKeyColumns = new List<ITableColumnInformation>();
}

public async ValueTask AddKeyColumn<TObj>(Expression<Func<TEntity, TObj>> propertyLambda, DbContext dbContext)
{
var entityTableInformationProvider = new EntityTableInformationProvider(dbContext);
var tableInformation = (EntityTableInformation)await entityTableInformationProvider.GetTableInformation(typeof(TEntity));
AddKeyColumn(propertyLambda, tableInformation);
}

public void AddKeyColumn<TObj>(Expression<Func<TEntity, TObj>> propertyLambda, EntityTableInformation tableInformation)
{
var property = propertyLambda.GetPropertyAccess();
var entityColumnInformation = tableInformation.Columns
.OfType<EntityColumnInformation>()
.SingleOrDefault(i => i.Property == property);

if (entityColumnInformation != null)
_primaryKeyColumns.Add(entityColumnInformation);
else
throw new InvalidOperationException($"Could not find column information for property {property.Name}");
}

public TableKey GetKeyColumns(ITableInformation tableInformation)
{
return new TableKey(_primaryKeyColumns, false);
}
}
2 changes: 1 addition & 1 deletion PgBulk.EFCore/EntityTableInformationProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public Task<ITableInformation> GetTableInformation(Type entityType)
var columns = model
.GetProperties()
.Where(p => p.PropertyInfo != null)
.Select(p => new EntityColumnInformation(p.GetColumnName(storeObjectIdentifier) ?? p.Name, p.IsPrimaryKey(), p.ValueGenerated == ValueGenerated.OnAdd, p.PropertyInfo));
.Select((p, i) => new EntityColumnInformation(p.GetColumnName(storeObjectIdentifier) ?? p.Name, p.IsPrimaryKey(), p.ValueGenerated == ValueGenerated.OnAdd, p.PropertyInfo, i));
var entityTableInformation = new EntityTableInformation(model.GetSchema() ?? model.GetDefaultSchema() ?? "public", tableName, columns);

return Task.FromResult((ITableInformation)entityTableInformation);
Expand Down
2 changes: 1 addition & 1 deletion PgBulk.EFCore/PgBulk.EFCore.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<Nullable>enable</Nullable>
<LangVersion>10</LangVersion>
<PackageProjectUrl>https://github.com/redbaty/PgBulk</PackageProjectUrl>
<PackageVersion>1.1.15</PackageVersion>
<PackageVersion>1.1.16</PackageVersion>
</PropertyGroup>

<ItemGroup>
Expand Down
38 changes: 38 additions & 0 deletions PgBulk.Tests/EFCoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,44 @@ public async Task Upsert(int value)

await myContext.Database.EnsureDeletedAsync();
}

[TestMethod]
[DataRow(100)]
[DataRow(1000)]
public async Task UpsertCustomKey(int value)
{
await using var myContext = CreateContext();
try
{
var customKeyProvider = new EntityManualTableKeyProvider<TestRow>();
await customKeyProvider.AddKeyColumn(i => i.Value1, myContext);

var testRows = Faker.Generate(value).OrderBy(i => i.Value1).ToArray();
await myContext.BulkMergeAsync(testRows, tableKeyProvider: customKeyProvider);

var currentCount = await myContext.TestRows.CountAsync();
Assert.AreEqual(value, currentCount);

var values = testRows.Select(i => i.Value1).Take(10).ToList();
var newRows = Faker
.RuleFor(x => x.Value1, f =>
{
var picked = f.PickRandom(values);
values.Remove(picked);
return picked;
})
.RuleFor(x => x.Id, f => f.IndexFaker + testRows.Length)
.Generate(10).OrderBy(i => i.Value1).ToArray();

await myContext.BulkMergeAsync(newRows, tableKeyProvider: customKeyProvider);
currentCount = await myContext.TestRows.CountAsync();
Assert.AreEqual(value, currentCount);
}
finally
{
await myContext.Database.EnsureDeletedAsync();
}
}

[TestMethod]
[DataRow(100)]
Expand Down
90 changes: 70 additions & 20 deletions PgBulk/BulkOperator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ public virtual async Task<NpgsqlConnection> CreateOpenedConnection()
await connection.OpenAsync();
return connection;
}

public async Task MergeAsync<T>(IEnumerable<T> entities)
public async Task MergeAsync<T>(ICollection<T> entities, ITableKeyProvider? tableKeyProvider = null)
{
var connection = await CreateOpenedConnection();

try
{
await MergeAsync(connection, entities);
await MergeAsync(connection, entities, tableKeyProvider ?? new DefaultTableKeyProvider());
}
finally
{
Expand Down Expand Up @@ -75,7 +75,7 @@ private async Task<ulong> InsertToTableAsync<T>(NpgsqlConnection npgsqlConnectio
return await InsertToTableAsync(npgsqlConnection, entities, tableInformation, tableInformation.Name);
}

public virtual async Task MergeAsync<T>(NpgsqlConnection connection, IEnumerable<T> entities, Func<string, string, Task>? runAfterTemporaryTableInsert = null)
public virtual async Task MergeAsync<T>(NpgsqlConnection connection, ICollection<T> entities, ITableKeyProvider tableKeyProvider, Func<string, string, Task>? runAfterTemporaryTableInsert = null)
{
var tableInformation = await TableInformationProvider.GetTableInformation(typeof(T));
var temporaryName = GetTemporaryTableName(tableInformation);
Expand All @@ -85,29 +85,69 @@ public virtual async Task MergeAsync<T>(NpgsqlConnection connection, IEnumerable

if (runAfterTemporaryTableInsert != null) await runAfterTemporaryTableInsert(tableInformation.Name, temporaryName);

var primaryKeyColumns = tableInformation.Columns
.Where(i => i is { PrimaryKey: true, ValueGeneratedOnAdd: false })
.Select(i => $"\"{i.Name}\"")
var tableKey = tableKeyProvider.GetKeyColumns(tableInformation);
var primaryKeyColumns = tableKey
.Columns
.Select(i => i.SafeName)
.DefaultIfEmpty()
.Aggregate((x, y) => $"{x},{y}");

if (string.IsNullOrEmpty(primaryKeyColumns))
throw new InvalidOperationException($"No primary keys defined for table \"{tableInformation.Name}\"");

var setStatement = tableInformation.Columns
.Where(i => !i.PrimaryKey)
.Select(i => $"\"{i.Name}\" = EXCLUDED.\"{i.Name}\"")
.DefaultIfEmpty()
.Aggregate((x, y) => $"{x}, {y}");
if (tableKey.IsUniqueConstraint)
{
var setStatement = tableInformation.Columns
.Where(i => !i.PrimaryKey)
.Select(i => $"\"{i.Name}\" = EXCLUDED.\"{i.Name}\"")
.DefaultIfEmpty()
.Aggregate((x, y) => $"{x}, {y}");

var baseCommand = new StringBuilder($"insert into \"{tableInformation.Schema}\".\"{tableInformation.Name}\" (select * from \"{temporaryName}\") ON CONFLICT ");

var baseCommand = new StringBuilder($"insert into \"{tableInformation.Schema}\".\"{tableInformation.Name}\" (select * from \"{temporaryName}\") ON CONFLICT ");
if (!string.IsNullOrEmpty(setStatement))
baseCommand.Append($"({primaryKeyColumns}) DO UPDATE SET {setStatement}");
else
baseCommand.Append("DO NOTHING");

if (!string.IsNullOrEmpty(setStatement))
baseCommand.Append($"({primaryKeyColumns}) DO UPDATE SET {setStatement}");
await ExecuteCommand(connection, baseCommand.ToString());
}
else
baseCommand.Append("DO NOTHING");

await ExecuteCommand(connection, baseCommand.ToString());
{
await using var transaction = await connection.BeginTransactionAsync();
try
{
var deleteScriptBuilder = new StringBuilder($"delete from \"{tableInformation.Schema}\".\"{tableInformation.Name}\" where ");
var first = true;

foreach (var column in tableKey.Columns.OrderBy(i => i.Index))
{
if(!first)
deleteScriptBuilder.Append(" and ");

deleteScriptBuilder.Append($"{column.SafeName} = @p{column.Index}");
first = false;
}

var deleteScript = deleteScriptBuilder.ToString();

foreach (var entity in entities)
{
var npgsqlParameters = tableKey.Columns
.Select(i => new NpgsqlParameter($"p{i.Index}", i.GetValue(entity)))
.ToArray();

await ExecuteCommand(connection, deleteScript, npgsqlParameters);
}

await ExecuteCommand(connection, $"insert into \"{tableInformation.Schema}\".\"{tableInformation.Name}\" (select * from \"{temporaryName}\")");
await transaction.CommitAsync();
}
finally
{
await transaction.DisposeAsync();
}
}
}

private async Task CreateTemporaryTable(NpgsqlConnection connection, ITableInformation sourceTable, string temporaryName)
Expand All @@ -116,17 +156,27 @@ private async Task CreateTemporaryTable(NpgsqlConnection connection, ITableInfor
await ExecuteCommand(connection, script);
}

private async Task ExecuteCommand(NpgsqlConnection connection, string script)
private async Task<int> ExecuteCommand(NpgsqlConnection connection, string script, IEnumerable<NpgsqlParameter>? parameters = null)
{
await using var npgsqlCommand = connection.CreateCommand();
npgsqlCommand.CommandText = script;

if (parameters != null)
{
foreach (var parameter in parameters)
{
npgsqlCommand.Parameters.Add(parameter);
}
}

LogBeforeCommand(npgsqlCommand);
var stopWatch = Stopwatch.StartNew();
await npgsqlCommand.ExecuteNonQueryAsync();
var updatedRows = await npgsqlCommand.ExecuteNonQueryAsync();

stopWatch.Start();
LogAfterCommand(npgsqlCommand, stopWatch.Elapsed);

return updatedRows;
}

public async Task SyncAsync<T>(IEnumerable<T> entities, string? deleteWhere = null, Func<string, string, Task>? runAfterTemporaryTableInsert = null)
Expand Down
13 changes: 13 additions & 0 deletions PgBulk/DefaultTableKeyProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using PgBulk.Abstractions;

namespace PgBulk;

public class DefaultTableKeyProvider : ITableKeyProvider
{
public TableKey GetKeyColumns(ITableInformation tableInformation)
{
return new TableKey(tableInformation.Columns
.Where(i => i is { PrimaryKey: true, ValueGeneratedOnAdd: false })
.ToArray(), true);
}
}
10 changes: 7 additions & 3 deletions PgBulk/ManualTableColumnInformationBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ public ManualTableColumnInformationBuilder(string tableName, string schema = "pu

public ManualTableColumnInformationBuilder<T> Automap()
{
foreach (var propertyInfo in typeof(T).GetProperties().Where(i => i is { CanRead: true, CanWrite: true }))
ColumnMappings.Add(new ManualTableColumnMapping(propertyInfo.Name, propertyInfo, false));
foreach (var propertyInfo in typeof(T).GetProperties().Where(i => i is { CanRead: true, CanWrite: true }))
{
var previousMax = ColumnMappings.Count < 1 ? 0 : ColumnMappings.Max(x => x.Index);
ColumnMappings.Add(new ManualTableColumnMapping(propertyInfo.Name, propertyInfo, false, previousMax + 1));
}

return this;
}

public ManualTableColumnInformationBuilder<T> Property<TObj>(Expression<Func<T, TObj>> propertyLambda, string columnName, bool primaryKey = false)
{
var propertyInfo = propertyLambda.GetProperty();
var columnMapping = new ManualTableColumnMapping(columnName, propertyInfo, primaryKey);
var previousMax = ColumnMappings.Count < 1 ? 0 : ColumnMappings.Max(x => x.Index);
var columnMapping = new ManualTableColumnMapping(columnName, propertyInfo, primaryKey, previousMax + 1);
ColumnMappings.Add(columnMapping);

return this;
Expand Down
7 changes: 5 additions & 2 deletions PgBulk/ManualTableColumnMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ namespace PgBulk;

public record ManualTableColumnMapping : ITableColumnInformation
{
public ManualTableColumnMapping(string name, PropertyInfo? property, bool valueGeneratedOnAdd, bool primaryKey = false)
public ManualTableColumnMapping(string name, PropertyInfo? property, bool valueGeneratedOnAdd, int index, bool primaryKey = false)
{
Name = name;
Property = property;
ValueGeneratedOnAdd = valueGeneratedOnAdd;
Index = index;
PrimaryKey = primaryKey;
}

internal PropertyInfo? Property { get; }


public int Index { get; }

public string Name { get; }

public bool PrimaryKey { get; internal set; }
Expand Down
Loading

0 comments on commit 9093148

Please sign in to comment.