Skip to content

Commit

Permalink
Use conventions and model snapshot in the migration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Dec 17, 2021
1 parent 0e1e95b commit e190686
Show file tree
Hide file tree
Showing 10 changed files with 516 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="$(MicrosoftCodeAnalysisVersion)" />
<PackageReference Include="Microsoft.Extensions.DependencyModel" Version="$(MicrosoftExtensionsDependencyModelVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Diagnostics.Internal;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Migrations.Internal;
using Microsoft.EntityFrameworkCore.Scaffolding.Metadata;

namespace Microsoft.EntityFrameworkCore.Migrations;
Expand Down Expand Up @@ -1141,7 +1141,7 @@ public virtual Task Rename_index()
});

[ConditionalFact]
public virtual Task Add_primary_key()
public virtual Task Add_primary_key_int()
=> Test(
builder => builder.Entity("People").Property<int>("SomeField"),
builder => { },
Expand All @@ -1159,10 +1159,29 @@ public virtual Task Add_primary_key()
}
});

[ConditionalFact]
public virtual Task Add_primary_key_string()
=> Test(
builder => builder.Entity("People").Property<string>("SomeField").IsRequired(),
builder => { },
builder => builder.Entity("People").HasKey("SomeField"),
model =>
{
var table = Assert.Single(model.Tables);
var primaryKey = table.PrimaryKey;
Assert.NotNull(primaryKey);
Assert.Same(table, primaryKey!.Table);
Assert.Same(table.Columns.Single(), Assert.Single(primaryKey.Columns));
if (AssertConstraintNames)
{
Assert.Equal("PK_People", primaryKey.Name);
}
});

[ConditionalFact]
public virtual Task Add_primary_key_with_name()
=> Test(
builder => builder.Entity("People").Property<int>("SomeField"),
builder => builder.Entity("People").Property<string>("SomeField"),
builder => { },
builder => builder.Entity("People").HasKey("SomeField").HasName("PK_Foo"),
model =>
Expand Down Expand Up @@ -1206,13 +1225,21 @@ public virtual Task Add_primary_key_composite_with_name()
});

[ConditionalFact]
public virtual Task Drop_primary_key()
public virtual Task Drop_primary_key_int()
=> Test(
builder => builder.Entity("People").Property<int>("SomeField"),
builder => builder.Entity("People").HasKey("SomeField"),
builder => { },
model => Assert.Null(Assert.Single(model.Tables).PrimaryKey));

[ConditionalFact]
public virtual Task Drop_primary_key_string()
=> Test(
builder => builder.Entity("People").Property<string>("SomeField").IsRequired(),
builder => builder.Entity("People").HasKey("SomeField"),
builder => { },
model => Assert.Null(Assert.Single(model.Tables).PrimaryKey));

[ConditionalFact]
public virtual Task Add_foreign_key()
=> Test(
Expand Down Expand Up @@ -1244,7 +1271,7 @@ public virtual Task Add_foreign_key()
Assert.Equal("FK_Orders_Customers_CustomerId", foreignKey.Name);
}
Assert.Equal(ReferentialAction.NoAction, foreignKey.OnDelete);
Assert.Equal(ReferentialAction.Cascade, foreignKey.OnDelete);
Assert.Same(customersTable, foreignKey.PrincipalTable);
Assert.Same(customersTable.Columns.Single(), Assert.Single(foreignKey.PrincipalColumns));
Assert.Equal("CustomerId", Assert.Single(foreignKey.Columns).Name);
Expand Down Expand Up @@ -1745,62 +1772,62 @@ protected virtual IRelationalTypeMappingSource TypeMappingSource
protected virtual Task Test(
Action<ModelBuilder> buildSourceAction,
Action<ModelBuilder> buildTargetAction,
Action<DatabaseModel> asserter)
=> Test(b => { }, buildSourceAction, buildTargetAction, asserter);
Action<DatabaseModel> asserter,
bool withConventions = true)
=> Test(_ => { }, buildSourceAction, buildTargetAction, asserter, withConventions);

protected virtual Task Test(
Action<ModelBuilder> buildCommonAction,
Action<ModelBuilder> buildSourceAction,
Action<ModelBuilder> buildTargetAction,
Action<DatabaseModel> asserter)
Action<DatabaseModel> asserter,
bool withConventions = true)
{
var context = CreateContext();
var modelDiffer = context.GetService<IMigrationsModelDiffer>();
var modelRuntimeInitializer = context.GetService<IModelRuntimeInitializer>();

// Build the source and target models. Add current/latest product version if one wasn't set.
var sourceModelBuilder = CreateConventionlessModelBuilder();
// Build the source model, possibly with conventions
var sourceModelBuilder = CreateModelBuilder(withConventions);
buildCommonAction(sourceModelBuilder);
buildSourceAction(sourceModelBuilder);
var sourceModel = modelRuntimeInitializer.Initialize(
var preSnapshotSourceModel = modelRuntimeInitializer.Initialize(
(IModel)sourceModelBuilder.Model, designTime: true, validationLogger: null);

var targetModelBuilder = CreateConventionlessModelBuilder();
// Round-trip the source model through a snapshot, compiling it and then extracting it back again.
// This simulates the real-world migration flow and can expose errors in snapshot generation
var migrationsCodeGenerator = Fixture.TestHelpers.CreateDesignServiceProvider().GetRequiredService<IMigrationsCodeGenerator>();
var sourceModelSnapshot = migrationsCodeGenerator.GenerateSnapshot(
modelSnapshotNamespace: null, typeof(DbContext), "MigrationsTestSnapshot", preSnapshotSourceModel);
var sourceModel = BuildModelFromSnapshotSource(sourceModelSnapshot);

// Build the target model, possibly with conventions
var targetModelBuilder = CreateModelBuilder(withConventions);
buildCommonAction(targetModelBuilder);
buildTargetAction(targetModelBuilder);

var targetModel = modelRuntimeInitializer.Initialize(
(IModel)targetModelBuilder.Model, designTime: true, validationLogger: null);

// Get the migration operations between the two models and test
var operations = modelDiffer.GetDifferences(sourceModel.GetRelationalModel(), targetModel.GetRelationalModel());

return Test(sourceModel, targetModel, operations, asserter);
}

protected DiagnosticsLogger<DbLoggerCategory.Model.Validation> CreateValidationLogger(bool sensitiveDataLoggingEnabled = false)
{
var options = new LoggingOptions();
options.Initialize(new DbContextOptionsBuilder().EnableSensitiveDataLogging(sensitiveDataLoggingEnabled).Options);
return new DiagnosticsLogger<DbLoggerCategory.Model.Validation>(
Fixture.TestSqlLoggerFactory,
options,
new DiagnosticListener("Fake"),
Fixture.TestHelpers.LoggingDefinitions,
new NullDbContextLogger());
}

protected virtual Task Test(
Action<ModelBuilder> buildSourceAction,
MigrationOperation operation,
Action<DatabaseModel> asserter)
=> Test(buildSourceAction, new[] { operation }, asserter);
Action<DatabaseModel> asserter,
bool withConventions = true)
=> Test(buildSourceAction, new[] { operation }, asserter, withConventions);

protected virtual Task Test(
Action<ModelBuilder> buildSourceAction,
IReadOnlyList<MigrationOperation> operations,
Action<DatabaseModel> asserter)
Action<DatabaseModel> asserter,
bool withConventions = true)
{
var sourceModelBuilder = CreateConventionlessModelBuilder();
var sourceModelBuilder = CreateModelBuilder(withConventions);
buildSourceAction(sourceModelBuilder);
if (sourceModelBuilder.Model.GetProductVersion() is null)
{
Expand All @@ -1809,9 +1836,16 @@ protected virtual Task Test(

var context = CreateContext();
var modelRuntimeInitializer = context.GetService<IModelRuntimeInitializer>();
var sourceModel = modelRuntimeInitializer.Initialize(
var preSnapshotSourceModel = modelRuntimeInitializer.Initialize(
(IModel)sourceModelBuilder.Model, designTime: true, validationLogger: null);

// Round-trip the source model through a snapshot, compiling it and then extracting it back again.
// This simulates the real-world migration flow and can expose errors in snapshot generation
var migrationsCodeGenerator = Fixture.TestHelpers.CreateDesignServiceProvider().GetRequiredService<IMigrationsCodeGenerator>();
var sourceModelSnapshot = migrationsCodeGenerator.GenerateSnapshot(
modelSnapshotNamespace: null, typeof(DbContext), "MigrationsTestSnapshot", preSnapshotSourceModel);
var sourceModel = BuildModelFromSnapshotSource(sourceModelSnapshot);

return Test(sourceModel, targetModel: null, operations, asserter);
}

Expand Down Expand Up @@ -1860,16 +1894,18 @@ await migrationsCommandExecutor.ExecuteNonQueryAsync(

protected virtual Task<T> TestThrows<T>(
Action<ModelBuilder> buildSourceAction,
Action<ModelBuilder> buildTargetAction)
Action<ModelBuilder> buildTargetAction,
bool withConventions = true)
where T : Exception
=> TestThrows<T>(b => { }, buildSourceAction, buildTargetAction);
=> TestThrows<T>(b => { }, buildSourceAction, buildTargetAction, withConventions);

protected virtual Task<T> TestThrows<T>(
Action<ModelBuilder> buildCommonAction,
Action<ModelBuilder> buildSourceAction,
Action<ModelBuilder> buildTargetAction)
Action<ModelBuilder> buildTargetAction,
bool withConventions = true)
where T : Exception
=> Assert.ThrowsAsync<T>(() => Test(buildCommonAction, buildSourceAction, buildTargetAction, asserter: null));
=> Assert.ThrowsAsync<T>(() => Test(buildCommonAction, buildSourceAction, buildTargetAction, asserter: null, withConventions));

protected virtual void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
Expand All @@ -1882,6 +1918,47 @@ public TestSqlLoggerFactory TestSqlLoggerFactory
=> (TestSqlLoggerFactory)ListLoggerFactory;
}

protected virtual ModelBuilder CreateConventionlessModelBuilder()
=> new(new ConventionSet());
private ModelBuilder CreateModelBuilder(bool withConventions)
=> withConventions ? Fixture.TestHelpers.CreateConventionBuilder() : new ModelBuilder(new ConventionSet());

protected IModel BuildModelFromSnapshotSource(string code)
{
var build = new BuildSource { Sources = { { "Snapshot.cs", code } } };

// Add standard EF references, a reference to the provider's assembly, and any extra references added by the provider's test suite
build.References.Add(BuildReference.ByName("Microsoft.EntityFrameworkCore"));
build.References.Add(BuildReference.ByName("Microsoft.EntityFrameworkCore.Relational"));

var databaseProvider = Fixture.TestHelpers.CreateContextServices().GetRequiredService<IDatabaseProvider>();
build.References.Add(BuildReference.ByName(databaseProvider.Name));

foreach (var buildReference in GetAdditionalReferences())
{
build.References.Add(buildReference);
}

var assembly = build.BuildInMemory();
var factoryType = assembly.GetType("MigrationsTestSnapshot");

var buildModelMethod = factoryType.GetMethod(
"BuildModel",
BindingFlags.Instance | BindingFlags.NonPublic,
null,
new[] { typeof(ModelBuilder) },
null);

var builder = new ModelBuilder();
builder.Model.RemoveAnnotation(CoreAnnotationNames.ProductVersion);

buildModelMethod.Invoke(
Activator.CreateInstance(factoryType),
new object[] { builder });

var services = Fixture.TestHelpers.CreateContextServices();
var processor = new SnapshotModelProcessor(new TestOperationReporter(), services.GetService<IModelRuntimeInitializer>());
return processor.Process(builder.Model);
}

protected virtual ICollection<BuildReference> GetAdditionalReferences()
=> Array.Empty<BuildReference>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ where IOPath.GetFileNameWithoutExtension(r) == name
if (references.Count == 0)
{
throw new InvalidOperationException(
$"Assembly '{name}' not found.");
$"Assembly '{name}' not found. " +
"You may be missing '<PreserveCompilationContext>true</PreserveCompilationContext>' in your test project's csproj.");
}

return new BuildReference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<TargetFramework>net6.0</TargetFramework>
<AssemblyName>Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests</AssemblyName>
<RootNamespace>Microsoft.EntityFrameworkCore</RootNamespace>
<PreserveCompilationContext>true</PreserveCompilationContext>
<SkipTests Condition="'$(OS)' != 'Windows_NT' AND '$(Test__SqlServer__DefaultConnection)' == ''">True</SkipTests>
<ImplicitUsings>true</ImplicitUsings>
</PropertyGroup>
Expand Down
Loading

0 comments on commit e190686

Please sign in to comment.