Skip to content

Commit

Permalink
Adds onConflictIgnore support
Browse files Browse the repository at this point in the history
This makes it so you can run BulkInsert and ignore if there's a conflict
  • Loading branch information
redbaty committed Nov 6, 2023
1 parent 9093148 commit 38d983a
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 26 deletions.
4 changes: 2 additions & 2 deletions PgBulk.EFCore/ContextExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ public static Task BulkMergeAsync<T>(this DbContext dbContext, ICollection<T> en
return @operator.MergeAsync(entities, tableKeyProvider);
}

public static Task BulkInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, int? timeoutOverride = 600, bool useContextConnection = true) where T : class
public static Task BulkInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, int? timeoutOverride = 600, bool useContextConnection = true, bool onConflictIgnore = false) where T : class
{
var @operator = new BulkEfOperator(dbContext, timeoutOverride, useContextConnection);
return @operator.InsertAsync(entities);
return @operator.InsertAsync(entities, onConflictIgnore);
}

public static BulkEfOperator GetBulkOperator(this DbContext dbContext, int? timeoutOverride = 600, bool useContextConnection = true)
Expand Down
2 changes: 1 addition & 1 deletion PgBulk.EFCore/PgBulk.EFCore.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<Nullable>enable</Nullable>
<LangVersion>10</LangVersion>
<PackageProjectUrl>https://github.com/redbaty/PgBulk</PackageProjectUrl>
<PackageVersion>1.1.16</PackageVersion>
<PackageVersion>1.1.17</PackageVersion>
</PropertyGroup>

<ItemGroup>
Expand Down
25 changes: 25 additions & 0 deletions PgBulk.Tests/EFCoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,31 @@ public async Task Insert(int value)

await myContext.Database.EnsureDeletedAsync();
}

[TestMethod]
[DataRow(100)]
[DataRow(1000)]
public async Task InsertConflictIgnore(int value)
{
await using var myContext = CreateContext();
var testRows = Faker.Generate(value);
await myContext.BulkInsertAsync(testRows);

var currentCount = await myContext.TestRows.CountAsync();
Assert.AreEqual(value, currentCount);

var newRows = Faker
.RuleFor(x => x.Id, f => f.IndexFaker + testRows.Count - 5)
.Generate(10)
.ToArray();

await myContext.BulkInsertAsync(newRows, onConflictIgnore: true);
currentCount = await myContext.TestRows.CountAsync();

Assert.AreEqual(value + 5, currentCount);

await myContext.Database.EnsureDeletedAsync();
}

[TestMethod]
[DataRow(100)]
Expand Down
56 changes: 34 additions & 22 deletions PgBulk/BulkOperator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public virtual async Task<NpgsqlConnection> CreateOpenedConnection()
await connection.OpenAsync();
return connection;
}

public async Task MergeAsync<T>(ICollection<T> entities, ITableKeyProvider? tableKeyProvider = null)
{
var connection = await CreateOpenedConnection();
Expand All @@ -54,13 +54,13 @@ public async Task MergeAsync<T>(ICollection<T> entities, ITableKeyProvider? tabl
}
}

public async Task InsertAsync<T>(IEnumerable<T> entities)
public async Task InsertAsync<T>(IEnumerable<T> entities, bool onConflictIgnore)
{
var connection = await CreateOpenedConnection();

try
{
await InsertToTableAsync(connection, entities);
await InsertToTableAsync(connection, entities, onConflictIgnore);
}
finally
{
Expand All @@ -69,10 +69,10 @@ public async Task InsertAsync<T>(IEnumerable<T> entities)
}
}

private async Task<ulong> InsertToTableAsync<T>(NpgsqlConnection npgsqlConnection, IEnumerable<T> entities)
private async Task<ulong> InsertToTableAsync<T>(NpgsqlConnection npgsqlConnection, IEnumerable<T> entities, bool onConflictIgnore)
{
var tableInformation = await TableInformationProvider.GetTableInformation(typeof(T));
return await InsertToTableAsync(npgsqlConnection, entities, tableInformation, tableInformation.Name);
return await InsertToTableAsync(npgsqlConnection, entities, tableInformation, tableInformation.Name, onConflictIgnore);
}

public virtual async Task MergeAsync<T>(NpgsqlConnection connection, ICollection<T> entities, ITableKeyProvider tableKeyProvider, Func<string, string, Task>? runAfterTemporaryTableInsert = null)
Expand All @@ -81,7 +81,7 @@ public virtual async Task MergeAsync<T>(NpgsqlConnection connection, ICollection
var temporaryName = GetTemporaryTableName(tableInformation);

await CreateTemporaryTable(connection, tableInformation, temporaryName);
await InsertToTableAsync(connection, entities, tableInformation, temporaryName);
await InsertToTableAsync(connection, entities, tableInformation, temporaryName, false);

if (runAfterTemporaryTableInsert != null) await runAfterTemporaryTableInsert(tableInformation.Name, temporaryName);

Expand Down Expand Up @@ -119,27 +119,27 @@ public virtual async Task MergeAsync<T>(NpgsqlConnection connection, ICollection
{
var deleteScriptBuilder = new StringBuilder($"delete from \"{tableInformation.Schema}\".\"{tableInformation.Name}\" where ");
var first = true;

foreach (var column in tableKey.Columns.OrderBy(i => i.Index))
{
if(!first)
if (!first)
deleteScriptBuilder.Append(" and ");

deleteScriptBuilder.Append($"{column.SafeName} = @p{column.Index}");
first = false;
}

var deleteScript = deleteScriptBuilder.ToString();

foreach (var entity in entities)
{
var npgsqlParameters = tableKey.Columns
.Select(i => new NpgsqlParameter($"p{i.Index}", i.GetValue(entity)))
.ToArray();

await ExecuteCommand(connection, deleteScript, npgsqlParameters);
}

await ExecuteCommand(connection, $"insert into \"{tableInformation.Schema}\".\"{tableInformation.Name}\" (select * from \"{temporaryName}\")");
await transaction.CommitAsync();
}
Expand Down Expand Up @@ -168,7 +168,7 @@ private async Task<int> ExecuteCommand(NpgsqlConnection connection, string scrip
npgsqlCommand.Parameters.Add(parameter);
}
}

LogBeforeCommand(npgsqlCommand);
var stopWatch = Stopwatch.StartNew();
var updatedRows = await npgsqlCommand.ExecuteNonQueryAsync();
Expand Down Expand Up @@ -200,7 +200,7 @@ public virtual async Task SyncAsync<T>(NpgsqlConnection connection, IEnumerable<
var temporaryName = GetTemporaryTableName(tableInformation);

await CreateTemporaryTable(connection, tableInformation, temporaryName);
await InsertToTableAsync(connection, entities, tableInformation, temporaryName);
await InsertToTableAsync(connection, entities, tableInformation, temporaryName, false);

if (runAfterTemporaryTableInsert != null) await runAfterTemporaryTableInsert(tableInformation.Name, temporaryName);

Expand Down Expand Up @@ -248,32 +248,44 @@ public async Task<NpgsqlBinaryImporter<T>> CreateBinaryImporterAsync<T>(NpgsqlCo
public async Task<NpgsqlBinaryImporter<T>> CreateBinaryImporterAsync<T>(NpgsqlConnection connection, IEnumerable<ITableColumnInformation> columns, string targetTableName, string? targetSchema = null)
{
var columnsFiltered = columns.Where(i => !i.ValueGeneratedOnAdd).ToList();
if(columnsFiltered.Count <= 0)

if (columnsFiltered.Count <= 0)
throw new InvalidOperationException("No valid columns found on type " + typeof(T).Name);

var columnsString = columnsFiltered
.Select(i => $"\"{i.Name}\"")
.Aggregate((x, y) => $"{x}, {y}");

var commandBuilder = new StringBuilder("COPY ");

if (!string.IsNullOrEmpty(targetSchema))
{
commandBuilder.Append($"\"{targetSchema}\".");
}

commandBuilder.Append($"\"{targetTableName}\" ({columnsString}) FROM STDIN (FORMAT BINARY)");

var command = commandBuilder.ToString();

#if NET5_0
return await Task.FromResult(new NpgsqlBinaryImporter<T>(connection.BeginBinaryImport(command), columnsFiltered));
#else
return new NpgsqlBinaryImporter<T>(await connection.BeginBinaryImportAsync(command), columnsFiltered);
#endif
}

private async Task<ulong> InsertToTableAsync<T>(NpgsqlConnection connection, IEnumerable<T> entities, ITableInformation tableInformation, string tableName, bool onConflictIgnore)
{
if (!onConflictIgnore) return await InsertToTableAsync(connection, entities, tableInformation, tableName);

var temporaryName = GetTemporaryTableName(tableInformation);
await CreateTemporaryTable(connection, tableInformation, temporaryName);
await InsertToTableAsync(connection, entities, tableInformation, temporaryName);
var count = await ExecuteCommand(connection, $"insert into \"{tableInformation.Schema}\".\"{tableInformation.Name}\" (select * from \"{temporaryName}\") ON CONFLICT DO NOTHING");
return (ulong)count;

}

private async Task<ulong> InsertToTableAsync<T>(NpgsqlConnection connection, IEnumerable<T> entities, ITableInformation tableInformation, string tableName)
{
await using var npgsqlBinaryImporter = await CreateBinaryImporterAsync<T>(connection, tableInformation.Columns, tableName);
Expand Down
2 changes: 1 addition & 1 deletion PgBulk/PgBulk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<Nullable>enable</Nullable>
<LangVersion>10</LangVersion>
<PackageProjectUrl>https://github.com/redbaty/PgBulk</PackageProjectUrl>
<PackageVersion>1.1.11</PackageVersion>
<PackageVersion>1.1.12</PackageVersion>
</PropertyGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net6.0'">
Expand Down

0 comments on commit 38d983a

Please sign in to comment.