Skip to content

Commit

Permalink
Add Entry method to DbSet to get entries for shared-type entity type …
Browse files Browse the repository at this point in the history
…instances (#28451)
  • Loading branch information
ajcvickers committed Jul 15, 2022
1 parent b767671 commit 5d64060
Show file tree
Hide file tree
Showing 5 changed files with 779 additions and 45 deletions.
4 changes: 3 additions & 1 deletion src/EFCore/ChangeTracking/LocalView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class LocalView<TEntity> :
private ObservableBackedBindingList<TEntity>? _bindingList;
private ObservableCollection<TEntity>? _observable;
private readonly DbContext _context;
private readonly IEntityType _entityType;
private int _countChanges;
private int? _count;
private bool _triggeringStateManagerChange;
Expand All @@ -72,6 +73,7 @@ public class LocalView<TEntity> :
public LocalView(DbSet<TEntity> set)
{
_context = set.GetService<ICurrentDbContext>().Context;
_entityType = set.EntityType;

set.GetService<ILocalViewListener>().RegisterView(StateManagerChangedHandler);
}
Expand Down Expand Up @@ -208,7 +210,7 @@ public virtual void Add(TEntity item)
// to Add it again since doing so would change its state to Added, which is probably not what
// was wanted in this case.

var entry = _context.GetDependencies().StateManager.GetOrCreateEntry(item);
var entry = _context.GetDependencies().StateManager.GetOrCreateEntry(item, _entityType);
if (entry.EntityState == EntityState.Deleted
|| entry.EntityState == EntityState.Detached)
{
Expand Down
13 changes: 13 additions & 0 deletions src/EFCore/DbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,19 @@ public virtual void RemoveRange(IEnumerable<TEntity> entities)
public virtual void UpdateRange(IEnumerable<TEntity> entities)
=> throw new NotSupportedException();

/// <summary>
/// Gets an <see cref="EntityEntry{TEntity}" /> for the given entity. The entry provides
/// access to change tracking information and operations for the entity.
/// </summary>
/// <remarks>
/// See <see href="https://aka.ms/efcore-docs-entity-entries">Accessing tracked entities in EF Core</see> for more information and
/// examples.
/// </remarks>
/// <param name="entity">The entity to get the entry for.</param>
/// <returns>The entry for the given entity.</returns>
public virtual EntityEntry<TEntity> Entry(TEntity entity)
=> throw new NotSupportedException();

/// <summary>
/// Returns an <see cref="IEnumerator{T}" /> which when enumerated will execute a query against the database
/// to load all entities from the database.
Expand Down
22 changes: 21 additions & 1 deletion src/EFCore/Internal/InternalDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ public override async Task AddRangeAsync(
foreach (var entity in entities)
{
await SetEntityStateAsync(
stateManager.GetOrCreateEntry(entity),
stateManager.GetOrCreateEntry(entity, EntityType),
EntityState.Added,
cancellationToken)
.ConfigureAwait(false);
Expand Down Expand Up @@ -426,6 +426,26 @@ public override void RemoveRange(IEnumerable<TEntity> entities)
public override void UpdateRange(IEnumerable<TEntity> entities)
=> SetEntityStates(Check.NotNull(entities, nameof(entities)), EntityState.Modified);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override EntityEntry<TEntity> Entry(TEntity entity)
{
Check.NotNull(entity, nameof(entity));

var entry = EntryWithoutDetectChanges(entity);

if (_context.ChangeTracker.AutoDetectChangesEnabled)
{
entry.DetectChanges();
}

return entry;
}

private IEntityFinder<TEntity> Finder
=> (IEntityFinder<TEntity>)_context.GetDependencies().EntityFinderFactory.Create(EntityType);

Expand Down
43 changes: 0 additions & 43 deletions test/EFCore.Tests/DbSetTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,37 +133,6 @@ public void Direct_use_of_Set_throws_if_context_disposed()
Assert.Throws<ObjectDisposedException>(() => context.Set<Category>()).Message);
}

[ConditionalFact]
public void Direct_use_of_Set_for_shared_type_throws_if_context_disposed()
{
var context = new EarlyLearningCenter();
context.Dispose();

Assert.StartsWith(
CoreStrings.ContextDisposed,
Assert.Throws<ObjectDisposedException>(() => context.Set<Dictionary<string, object>>("SharedTypeEntityTypeName")).Message);
}

[ConditionalFact]
public void Using_shared_type_entity_type_db_set_with_incorrect_return_type_throws()
{
using var context = new EarlyLearningCenter();

var dbSet = context.Set<Dictionary<string, object>>("SharedEntity");

Assert.NotNull(dbSet.Add(new Dictionary<string, object> { { "Id", 1 } }));
Assert.NotNull(dbSet.ToList());

var wrongDbSet = context.Set<Category>("SharedEntity");

Assert.Equal(
CoreStrings.DbSetIncorrectGenericType("SharedEntity", "Dictionary<string, object>", "Category"),
Assert.Throws<InvalidOperationException>(() => wrongDbSet.Add(new Category())).Message);
Assert.Equal(
CoreStrings.DbSetIncorrectGenericType("SharedEntity", "Dictionary<string, object>", "Category"),
Assert.Throws<InvalidOperationException>(() => wrongDbSet.ToList()).Message);
}

[ConditionalFact]
public void Use_of_LocalView_throws_if_context_is_disposed()
{
Expand Down Expand Up @@ -735,11 +704,6 @@ public async Task Can_enumerate_with_await_foreach_with_cancellation()
}
}

private class Curious
{
public string George { get; set; }
}

private class Category
{
public int Id { get; set; }
Expand Down Expand Up @@ -769,12 +733,5 @@ protected internal override void OnConfiguring(DbContextOptionsBuilder optionsBu
=> optionsBuilder
.UseInMemoryDatabase(Guid.NewGuid().ToString())
.UseInternalServiceProvider(InMemoryTestHelpers.Instance.CreateServiceProvider());

protected internal override void OnModelCreating(ModelBuilder modelBuilder)
{
base.OnModelCreating(modelBuilder);

modelBuilder.SharedTypeEntity<Dictionary<string, object>>("SharedEntity").IndexerProperty<int>("Id");
}
}
}
Loading

0 comments on commit 5d64060

Please sign in to comment.