Skip to content

Commit

Permalink
Proof of concept for #32385
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers committed Dec 2, 2023
1 parent 32c6133 commit 995ac07
Show file tree
Hide file tree
Showing 5 changed files with 345 additions and 157 deletions.
2 changes: 1 addition & 1 deletion src/EFCore.Design/EFCore.Design.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(MicrosoftCodeAnalysisVersion)" />
<PackageReference Include="Microsoft.Extensions.DependencyModel" Version="$(MicrosoftExtensionsDependencyModelVersion)" />
<PackageReference Include="Microsoft.Extensions.HostFactoryResolver.Sources" PrivateAssets="All" Version="$(MicrosoftExtensionsHostFactoryResolverSourcesVersion)" />
<PackageReference Include="Mono.TextTemplating" Version="2.2.1" />
<PackageReference Include="Mono.TextTemplating" Version="2.3.1" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.CodeDom.Compiler;
using System.Globalization;
using System.Text;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.VisualStudio.TextTemplating;
using Mono.TextTemplating;

namespace Microsoft.EntityFrameworkCore.Scaffolding.Internal;
Expand Down Expand Up @@ -102,12 +104,18 @@ public override ScaffoldedModel GenerateModel(IModel model, ModelCodeGenerationO
};
var contextTemplate = Path.Combine(options.ProjectDir!, TemplatesDirectory, DbContextTemplate);

string generatedCode;
string? generatedCode = null;
if (File.Exists(contextTemplate))
{
host.TemplateFile = contextTemplate;

generatedCode = Engine.ProcessTemplate(File.ReadAllText(contextTemplate), host);
var compiledTemplate = Engine.CompileTemplateAsync(File.ReadAllText(contextTemplate), host, new()).GetAwaiter().GetResult();

if (compiledTemplate != null)
{
generatedCode = ProcessTemplate(compiledTemplate, host);
}

CheckEncoding(host.OutputEncoding);
HandleErrors(host);
}
Expand Down Expand Up @@ -142,7 +150,7 @@ public override ScaffoldedModel GenerateModel(IModel model, ModelCodeGenerationO
Path = options.ContextDir != null
? Path.Combine(options.ContextDir, dbContextFileName)
: dbContextFileName,
Code = generatedCode
Code = generatedCode!
}
};

Expand All @@ -165,12 +173,16 @@ public override ScaffoldedModel GenerateModel(IModel model, ModelCodeGenerationO

if (compiledEntityTypeTemplate is null)
{
compiledEntityTypeTemplate = Engine.CompileTemplate(File.ReadAllText(entityTypeTemplate), host);
compiledEntityTypeTemplate = Engine.CompileTemplateAsync(File.ReadAllText(entityTypeTemplate), host, new()).GetAwaiter().GetResult();;
entityTypeExtension = host.Extension;
CheckEncoding(host.OutputEncoding);
}

generatedCode = compiledEntityTypeTemplate.Process();
if (compiledEntityTypeTemplate != null)
{
generatedCode = ProcessTemplate(compiledEntityTypeTemplate, host);
}

HandleErrors(host);

if (string.IsNullOrWhiteSpace(generatedCode))
Expand Down Expand Up @@ -208,12 +220,16 @@ public override ScaffoldedModel GenerateModel(IModel model, ModelCodeGenerationO

if (compiledConfigurationTemplate is null)
{
compiledConfigurationTemplate = Engine.CompileTemplate(File.ReadAllText(configurationTemplate), host);
compiledConfigurationTemplate = Engine.CompileTemplateAsync(File.ReadAllText(configurationTemplate), host, new()).GetAwaiter().GetResult();;
configurationExtension = host.Extension;
CheckEncoding(host.OutputEncoding);
}

generatedCode = compiledConfigurationTemplate.Process();
if (compiledConfigurationTemplate != null)
{
generatedCode = ProcessTemplate(compiledConfigurationTemplate, host);
}

HandleErrors(host);

if (string.IsNullOrWhiteSpace(generatedCode))
Expand Down Expand Up @@ -241,6 +257,61 @@ public override ScaffoldedModel GenerateModel(IModel model, ModelCodeGenerationO
return resultingFiles;
}

private static string ProcessTemplate(CompiledTemplate compiledTemplate, TextTemplatingEngineHost host)
{
var templateAssemblyData = GetField(compiledTemplate, "templateAssemblyData")!;
var templateClassFullName = (string)GetField(compiledTemplate, "templateClassFullName")!;
var culture = GetField(compiledTemplate, "culture");
var assemblyBytes = (byte[])templateAssemblyData.GetType().GetProperty("Assembly")!.GetValue(templateAssemblyData)!;

var assembly = Assembly.Load(assemblyBytes);
var transformType = assembly.GetType(templateClassFullName)!;
var textTransformation = Activator.CreateInstance(transformType);

var hostProp = transformType.GetProperty("Host", typeof(ITextTemplatingEngineHost));
if (hostProp != null)
{
hostProp.SetValue(textTransformation, host, null);
}

var sessionProp = transformType.GetProperty("Session", typeof(IDictionary<string, object>));
if (sessionProp != null)
{
sessionProp.SetValue(textTransformation, host.Session, null);
}

var errorProp = transformType.GetProperty("Errors", BindingFlags.Instance | BindingFlags.NonPublic)!;
var errorMethod = transformType.GetMethod("Error", new[] { typeof(string) })!;

var errors = (CompilerErrorCollection)errorProp.GetValue(textTransformation, null)!;
errors.Clear();

ToStringHelper.FormatProvider = culture != null ? (IFormatProvider)culture : CultureInfo.InvariantCulture;

string? output = null;

var initMethod = transformType.GetMethod("Initialize")!;
var transformMethod = transformType.GetMethod("TransformText")!;

try
{
initMethod.Invoke(textTransformation, null);
output = (string?)transformMethod.Invoke(textTransformation, null);
}
catch (Exception ex)
{
errorMethod.Invoke(textTransformation, new object[] { "Error running transform: " + ex });
}

host.LogErrors(errors);
return output!;

static object? GetField(CompiledTemplate compiledTemplate, string fieldName)
=> compiledTemplate.GetType()
.GetField(fieldName, BindingFlags.Instance | BindingFlags.NonPublic)!
.GetValue(compiledTemplate);
}

private void CheckEncoding(Encoding outputEncoding)
{
if (outputEncoding != Encoding.UTF8)
Expand Down
Loading

0 comments on commit 995ac07

Please sign in to comment.