Skip to content

Commit

Permalink
Parse simple default constraint literals when scaffolding (#30927)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers authored Jun 12, 2023
1 parent 27d1665 commit 101e88e
Show file tree
Hide file tree
Showing 10 changed files with 861 additions and 131 deletions.
17 changes: 17 additions & 0 deletions src/EFCore.Design/Extensions/ScaffoldingModelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,23 @@ public static IEnumerable<AttributeCodeFragment> GetDataAnnotations(
{
FluentApiCodeFragment? root = null;

if (annotatable is IProperty property
&& annotations.TryGetValue(RelationalAnnotationNames.DefaultValueSql, out _)
&& annotations.TryGetValue(RelationalAnnotationNames.DefaultValue, out var parsedAnnotation))
{
if (Equals(property.ClrType.GetDefaultValue(), parsedAnnotation.Value))
{
// Default value is CLR default for property, so exclude it from scaffolded model
annotations.Remove(RelationalAnnotationNames.DefaultValueSql);
annotations.Remove(RelationalAnnotationNames.DefaultValue);
}
else
{
// SQL was parsed, so use parsed value and exclude raw value
annotations.Remove(RelationalAnnotationNames.DefaultValueSql);
}
}

foreach (var methodCall in annotationCodeGenerator.GenerateFluentApiCalls(annotatable, annotations))
{
var fluentApiCall = FluentApiCodeFragment.From(methodCall);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,11 @@ protected virtual EntityTypeBuilder VisitColumns(EntityTypeBuilder builder, ICol
property.ValueGeneratedOnAddOrUpdate();
}

if (column.DefaultValue != null)
{
property.HasDefaultValue(column.DefaultValue);
}

if (column.DefaultValueSql != null)
{
property.HasDefaultValueSql(column.DefaultValueSql);
Expand Down
5 changes: 5 additions & 0 deletions src/EFCore.Relational/Scaffolding/Metadata/DatabaseColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public class DatabaseColumn : Annotatable
/// </summary>
public virtual string? StoreType { get; set; }

/// <summary>
/// The default value for the column, or <see langword="null" /> if there is no default value or it could not be parsed.
/// </summary>
public virtual object? DefaultValue { get; set; }

/// <summary>
/// The default constraint for the column, or <see langword="null" /> if none.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,6 @@ public override IReadOnlyList<MethodCallCodeFragment> GenerateFluentApiCalls(
return fragments;
}

/// <inheritdoc />
public override IReadOnlyList<MethodCallCodeFragment> GenerateFluentApiCalls(
IRelationalPropertyOverrides overrides,
IDictionary<string, IAnnotation> annotations)
=> base.GenerateFluentApiCalls(overrides, annotations);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Scaffolding.Internal;
public class SqlServerDatabaseModelFactory : DatabaseModelFactory
{
private readonly IDiagnosticsLogger<DbLoggerCategory.Scaffolding> _logger;
private readonly IRelationalTypeMappingSource _typeMappingSource;

private static readonly ISet<string> DateTimePrecisionTypes = new HashSet<string>
{
Expand Down Expand Up @@ -82,9 +83,12 @@ private static readonly Regex PartExtractor
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlServerDatabaseModelFactory(IDiagnosticsLogger<DbLoggerCategory.Scaffolding> logger)
public SqlServerDatabaseModelFactory(
IDiagnosticsLogger<DbLoggerCategory.Scaffolding> logger,
IRelationalTypeMappingSource typeMappingSource)
{
_logger = logger;
_typeMappingSource = typeMappingSource;
}

/// <summary>
Expand Down Expand Up @@ -788,7 +792,7 @@ FROM [sys].[views] v
var scale = dataRecord.GetValueOrDefault<int>("scale");
var nullable = dataRecord.GetValueOrDefault<bool>("is_nullable");
var isIdentity = dataRecord.GetValueOrDefault<bool>("is_identity");
var defaultValue = dataRecord.GetValueOrDefault<string>("default_sql");
var defaultValueSql = dataRecord.GetValueOrDefault<string>("default_sql");
var computedValue = dataRecord.GetValueOrDefault<string>("computed_sql");
var computedIsPersisted = dataRecord.GetValueOrDefault<bool>("computed_is_persisted");
var comment = dataRecord.GetValueOrDefault<string>("comment");
Expand All @@ -811,7 +815,7 @@ FROM [sys].[views] v
scale,
nullable,
isIdentity,
defaultValue,
defaultValueSql,
computedValue,
computedIsPersisted);

Expand All @@ -830,15 +834,14 @@ FROM [sys].[views] v
systemTypeName = dataTypeName;
}

defaultValue = FilterClrDefaults(systemTypeName, nullable, defaultValue);

var column = new DatabaseColumn
{
Table = table,
Name = columnName,
StoreType = storeType,
IsNullable = nullable,
DefaultValueSql = defaultValue,
DefaultValue = TryParseClrDefault(systemTypeName, defaultValueSql),
DefaultValueSql = defaultValueSql,
ComputedColumnSql = computedValue,
IsStored = computedIsPersisted,
Comment = comment,
Expand Down Expand Up @@ -868,48 +871,110 @@ FROM [sys].[views] v
}
}

private static string? FilterClrDefaults(string dataTypeName, bool nullable, string? defaultValue)
private object? TryParseClrDefault(string dataTypeName, string? defaultValueSql)
{
if (defaultValue == null
|| defaultValue == "(NULL)")
defaultValueSql = defaultValueSql?.Trim();
if (string.IsNullOrEmpty(defaultValueSql))
{
return null;
}

var mapping = _typeMappingSource.FindMapping(dataTypeName);
if (mapping == null)
{
return null;
}

if (nullable)
Unwrap();
if (defaultValueSql.StartsWith("CONVERT", StringComparison.OrdinalIgnoreCase))
{
return defaultValue;
defaultValueSql = defaultValueSql.Substring(defaultValueSql.IndexOf(',') + 1);
defaultValueSql = defaultValueSql.Substring(0, defaultValueSql.LastIndexOf(')'));
Unwrap();
}

if (defaultValue == "((0))" || defaultValue == "(0)")
if (defaultValueSql.Equals("NULL", StringComparison.OrdinalIgnoreCase))
{
if (dataTypeName is "bigint" or "bit" or "decimal" or "float" or "int" or "money" or "numeric" or "real" or "smallint"
or "smallmoney" or "tinyint")
return null;
}

var type = mapping.ClrType;
if (type == typeof(bool)
&& int.TryParse(defaultValueSql, out var intValue))
{
return intValue != 0;
}

if (type.IsNumeric())
{
try
{
return Convert.ChangeType(defaultValueSql, type);
}
catch
{
// Ignored
return null;
}
}
else if (defaultValue == "((0.0))" || defaultValue == "(0.0)")

if ((defaultValueSql.StartsWith('\'') || defaultValueSql.StartsWith("N'", StringComparison.OrdinalIgnoreCase))
&& defaultValueSql.EndsWith('\''))
{
if (dataTypeName is "decimal" or "float" or "money" or "numeric" or "real" or "smallmoney")
var startIndex = defaultValueSql.IndexOf('\'');
defaultValueSql = defaultValueSql.Substring(startIndex + 1, defaultValueSql.Length - (startIndex + 2));

if (type == typeof(string))
{
return null;
return defaultValueSql;
}

if (type == typeof(bool)
&& bool.TryParse(defaultValueSql, out var boolValue))
{
return boolValue;
}

if (type == typeof(Guid)
&& Guid.TryParse(defaultValueSql, out var guid))
{
return guid;
}

if (type == typeof(DateTime)
&& DateTime.TryParse(defaultValueSql, out var dateTime))
{
return dateTime;
}

if (type == typeof(DateOnly)
&& DateOnly.TryParse(defaultValueSql, out var dateOnly))
{
return dateOnly;
}

if (type == typeof(TimeOnly)
&& TimeOnly.TryParse(defaultValueSql, out var timeOnly))
{
return timeOnly;
}

if (type == typeof(DateTimeOffset)
&& DateTimeOffset.TryParse(defaultValueSql, out var dateTimeOffset))
{
return dateTimeOffset;
}
}
else if ((defaultValue == "(CONVERT([real],(0)))" && dataTypeName == "real")
|| (defaultValue == "((0.0000000000000000e+000))" && dataTypeName == "float")
|| (defaultValue == "(0.0000000000000000e+000)" && dataTypeName == "float")
|| (defaultValue == "('0001-01-01')" && dataTypeName == "date")
|| (defaultValue == "('1900-01-01T00:00:00.000')" && (dataTypeName == "datetime" || dataTypeName == "smalldatetime"))
|| (defaultValue == "('0001-01-01T00:00:00.000')" && dataTypeName == "datetime2")
|| (defaultValue == "('0001-01-01T00:00:00.000+00:00')" && dataTypeName == "datetimeoffset")
|| (defaultValue == "('00:00:00')" && dataTypeName == "time")
|| (defaultValue == "('00000000-0000-0000-0000-000000000000')" && dataTypeName == "uniqueidentifier"))

return null;

void Unwrap()
{
return null;
while (defaultValueSql.StartsWith('(') && defaultValueSql.EndsWith(')'))
{
defaultValueSql = (defaultValueSql.Substring(1, defaultValueSql.Length - 2)).Trim();
}
}

return defaultValue;
}

private static string GetStoreType(string dataTypeName, int maxLength, int precision, int scale)
Expand Down Expand Up @@ -1190,11 +1255,11 @@ private void GetForeignKeys(DbConnection connection, IReadOnlyList<DatabaseTable
[t].[name] AS [table_name],
[f].[name],
SCHEMA_NAME(tab2.[schema_id]) AS [principal_table_schema],
[tab2].name AS [principal_table_name],
[tab2].name AS [principal_table_name],
[f].[delete_referential_action_desc],
[col1].[name] AS [column_name],
[col2].[name] AS [referenced_column_name]
FROM [sys].[foreign_keys] AS [f]
FROM [sys].[foreign_keys] AS [f]
JOIN [sys].[foreign_key_columns] AS fc ON [fc].[constraint_object_id] = [f].[object_id]
JOIN [sys].[tables] AS [t] ON [t].[object_id] = [fc].[parent_object_id]
JOIN [sys].[columns] AS [col1] ON [col1].[column_id] = [fc].[parent_column_id] AND [col1].[object_id] = [t].[object_id]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,62 @@ public Task ComputedColumnSql_works()
Assert.Equal("1 + 2", entity.GetProperty("ComputedColumn").GetComputedColumnSql());
});

[ConditionalFact]
public Task Column_with_default_value_only_uses_default_value()
=> TestAsync(
serviceProvider => serviceProvider.GetService<IScaffoldingModelFactory>().Create(
BuildModelWithColumn("nvarchar(max)", null, "Hot"), new ModelReverseEngineerOptions()),
new ModelCodeGenerationOptions(),
code => Assert.Contains($".HasDefaultValue(\"Hot\")", code.ContextFile.Code),
model =>
{
var property = model.FindEntityType("TestNamespace.Table")!.GetProperty("Column");
Assert.Equal("Hot", property.GetDefaultValue());
Assert.Null(property.FindAnnotation(RelationalAnnotationNames.DefaultValueSql));
});

[ConditionalFact]
public Task Column_with_default_value_sql_only_uses_default_value_sql()
=> TestAsync(
serviceProvider => serviceProvider.GetService<IScaffoldingModelFactory>().Create(
BuildModelWithColumn("nvarchar(max)", "('Hot')", null), new ModelReverseEngineerOptions()),
new ModelCodeGenerationOptions(),
code => Assert.Contains($".HasDefaultValueSql(\"('Hot')\")", code.ContextFile.Code),
model =>
{
var property = model.FindEntityType("TestNamespace.Table")!.GetProperty("Column");
Assert.Equal("('Hot')", property.GetDefaultValueSql());
Assert.Null(property.FindAnnotation(RelationalAnnotationNames.DefaultValue));
});

[ConditionalFact]
public Task Column_with_default_value_sql_and_default_value_uses_default_value()
=> TestAsync(
serviceProvider => serviceProvider.GetService<IScaffoldingModelFactory>().Create(
BuildModelWithColumn("nvarchar(max)", "('Hot')", "Hot"), new ModelReverseEngineerOptions()),
new ModelCodeGenerationOptions(),
code => Assert.Contains($".HasDefaultValue(\"Hot\")", code.ContextFile.Code),
model =>
{
var property = model.FindEntityType("TestNamespace.Table")!.GetProperty("Column");
Assert.Equal("Hot", property.GetDefaultValue());
Assert.Null(property.FindAnnotation(RelationalAnnotationNames.DefaultValueSql));
});

[ConditionalFact]
public Task Column_with_default_value_sql_and_default_value_where_value_is_CLR_default_uses_neither()
=> TestAsync(
serviceProvider => serviceProvider.GetService<IScaffoldingModelFactory>().Create(
BuildModelWithColumn("int", "((0))", 0), new ModelReverseEngineerOptions()),
new ModelCodeGenerationOptions(),
code => Assert.DoesNotContain("HasDefaultValue", code.ContextFile.Code),
model =>
{
var property = model.FindEntityType("TestNamespace.Table")!.GetProperty("Column");
Assert.Null(property.FindAnnotation(RelationalAnnotationNames.DefaultValue));
Assert.Null(property.FindAnnotation(RelationalAnnotationNames.DefaultValueSql));
});

[ConditionalFact]
public Task IsUnicode_works()
=> TestAsync(
Expand Down
Loading

0 comments on commit 101e88e

Please sign in to comment.