Skip to content

Commit

Permalink
Merge pull request #4 from microsoft/alzollin/generatorImprovements
Browse files Browse the repository at this point in the history
Simplified generated projects.
  • Loading branch information
nmetulev authored Nov 18, 2024
2 parents ee8c64a + 1c3da64 commit 4a2421d
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 64 deletions.
74 changes: 40 additions & 34 deletions AIDevGallery/ProjectGenerator/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,26 +222,8 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
};

string modelTemplateString = GetPromptTemplateString(modelInfos.Values.Select(m => m.ModelPromptTemplate).ToList());
string modelPathString;
string hardwareAcceleratorString;
string sampleNavigationParameterName;
bool isMultiModel;
if (modelInfos.Count > 1)
{
modelPathString = "[" + Environment.NewLine + string.Join($",{Environment.NewLine}", modelInfos.Values.Select(m => " " + m.ModelPathStr)) + Environment.NewLine + " ]";
isMultiModel = true;
hardwareAcceleratorString = "[" + string.Join(", ", modelInfos.Values.Select(m => $"HardwareAccelerator.{m.HardwareAccelerator}")) + "]";
sampleNavigationParameterName = "MultiModelSampleNavigationParameters";
}
else
{
modelPathString = modelInfos.Values.First().ModelPathStr;
isMultiModel = false;
hardwareAcceleratorString = $"HardwareAccelerator.{modelInfos.Values.First().HardwareAccelerator}";
sampleNavigationParameterName = "SampleNavigationParameters";
}

var className = await AddFilesFromSampleAsync(sample, packageReferences, safeProjectName, outputPath, addLllmTypes, isMultiModel, cancellationToken);
var className = await AddFilesFromSampleAsync(sample, packageReferences, safeProjectName, outputPath, addLllmTypes, modelInfos, cancellationToken);

foreach (var file in files)
{
Expand All @@ -253,12 +235,6 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
relativePath = relativePath.Replace(fileName, newName);
}

if ((fileName == "MultiModelSampleNavigationParameters.cs" && modelInfos.Count == 1) ||
(fileName == "SampleNavigationParameters.cs" && modelInfos.Count > 1))
{
continue;
}

var outputPathFile = Path.Join(outputPath, relativePath);

// Create the directory if it doesn't exist
Expand Down Expand Up @@ -287,10 +263,7 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
content = content.Replace("$XmlEscapedPublisher$", xmlEscapedPublisher);
content = content.Replace("$DotNetVersion$", DotNetVersion);
content = content.Replace("$MainSamplePage$", className);
content = content.Replace("$modelPath$", modelPathString);
content = content.Replace("$modelHardwareAccelerator$", hardwareAcceleratorString);
content = content.Replace("$promptTemplate$", modelTemplateString);
content = content.Replace("$sampleNavigationParameterName$", sampleNavigationParameterName);

// Write the file
await File.WriteAllTextAsync(outputPathFile, content, cancellationToken);
Expand Down Expand Up @@ -386,7 +359,7 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
return outputPath;
}

private string GetChatClientLoaderString(Sample sample, bool isMultiModel)
private string GetChatClientLoaderString(Sample sample, bool isMultiModel, string modelPath)
{
if (!sample.SharedCode.Contains(SharedCodeEnum.GenAIModel))
{
Expand All @@ -395,11 +368,11 @@ private string GetChatClientLoaderString(Sample sample, bool isMultiModel)

if (isMultiModel)
{
return "GenAIModel.CreateAsync(sampleParams.ModelPaths[0], sampleParams.PromptTemplates[0], sampleParams.CancellationToken)";
return $"GenAIModel.CreateAsync({modelPath}, sampleParams.PromptTemplates[0], System.Threading.CancellationToken.None)";
}
else
{
return "GenAIModel.CreateAsync(sampleParams.ModelPath, sampleParams.PromptTemplate, sampleParams.CancellationToken)";
return $"GenAIModel.CreateAsync({modelPath}, sampleParams.PromptTemplate, System.Threading.CancellationToken.None)";
}
}

Expand Down Expand Up @@ -511,7 +484,14 @@ private string GetPromptTemplateString(List<PromptTemplate?> promptTemplates)
return modelPromptTemplateSb.ToString();
}

private async Task<string> AddFilesFromSampleAsync(Sample sample, List<(string PackageName, string? Version)> packageReferences, string safeProjectName, string outputPath, bool addLllmTypes, bool isMultiModel, CancellationToken cancellationToken)
private async Task<string> AddFilesFromSampleAsync(
Sample sample,
List<(string PackageName, string? Version)> packageReferences,
string safeProjectName,
string outputPath,
bool addLllmTypes,
Dictionary<ModelType, (string CachedModelDirectoryPath, string ModelUrl, bool IsSingleFile, string ModelPathStr, HardwareAccelerator HardwareAccelerator, PromptTemplate? ModelPromptTemplate)> modelInfos,
CancellationToken cancellationToken)
{
var sharedCode = sample.SharedCode.ToList();
if (!sharedCode.Contains(SharedCodeEnum.LlmPromptTemplate) &&
Expand Down Expand Up @@ -557,11 +537,37 @@ private async Task<string> AddFilesFromSampleAsync(Sample sample, List<(string P
if (!string.IsNullOrEmpty(sample.CSCode))
{
var cleanCsSource = CleanCsSource(sample.CSCode, safeProjectName, true);
var chatClientLoader = GetChatClientLoaderString(sample, isMultiModel);
cleanCsSource = cleanCsSource.Replace("sampleParams.NotifyCompletion();", "App.Window?.ModelLoaded();");

string modelPath;
if (modelInfos.Count > 1)
{
cleanCsSource = cleanCsSource.Replace("MultiModelSampleNavigationParameters", "SampleNavigationParameters");

int i = 0;
foreach (var modelInfo in modelInfos)
{
cleanCsSource = cleanCsSource.Replace($"sampleParams.HardwareAccelerators[{i}]", $"HardwareAccelerator.{modelInfo.Value.HardwareAccelerator}");
cleanCsSource = cleanCsSource.Replace($"sampleParams.ModelPaths[{i}]", modelInfo.Value.ModelPathStr);
i++;
}

modelPath = modelInfos.First().Value.ModelPathStr;
}
else
{
var modelInfo = modelInfos.Values.First();
cleanCsSource = cleanCsSource.Replace("sampleParams.HardwareAccelerator", $"HardwareAccelerator.{modelInfo.HardwareAccelerator}");
cleanCsSource = cleanCsSource.Replace("sampleParams.ModelPath", modelInfo.ModelPathStr);
modelPath = modelInfo.ModelPathStr;
}

cleanCsSource = cleanCsSource.Replace("sampleParams.CancellationToken", "CancellationToken.None");

var chatClientLoader = GetChatClientLoaderString(sample, modelInfos.Count > 1, modelPath);
if (chatClientLoader != null)
{
cleanCsSource = cleanCsSource.Replace("sampleParams.GetIChatClientAsync()", chatClientLoader);
cleanCsSource = cleanCsSource.Replace("sampleParams.NotifyCompletion();", "App.Window?.ModelLoaded();");
}

await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml.cs"), cleanCsSource, cancellationToken);
Expand Down
6 changes: 1 addition & 5 deletions AIDevGallery/ProjectGenerator/Template/MainWindow.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@ public MainWindow()
this.InitializeComponent();
this.RootFrame.Loaded += (sender, args) =>
{
var sampleLoadingCts = new CancellationTokenSource();

var localModelDetails = new $sampleNavigationParameterName$(sampleLoadingCts.Token);

RootFrame.Navigate(typeof($MainSamplePage$), localModelDetails);
RootFrame.Navigate(typeof($MainSamplePage$), new SampleNavigationParameters());
};
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
<RootNamespace>$safeprojectname$</RootNamespace>
<ApplicationManifest>app.manifest</ApplicationManifest>
<Platforms>x86;x64;ARM64</Platforms>
<RuntimeIdentifiers Condition="$([MSBuild]::GetTargetFrameworkVersion('$(TargetFramework)')) >= 8">win-x86;win-x64;win-arm64</RuntimeIdentifiers>
<RuntimeIdentifiers Condition="$([MSBuild]::GetTargetFrameworkVersion('$(TargetFramework)')) &lt; 8">win10-x86;win10-x64;win10-arm64</RuntimeIdentifiers>
<RuntimeIdentifiers>win-x86;win-x64;win-arm64</RuntimeIdentifiers>
<PublishProfile>win-$(Platform).pubxml</PublishProfile>
<UseWinUI>true</UseWinUI>
<EnableMsixTooling>true</EnableMsixTooling>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ https://go.microsoft.com/fwlink/?LinkID=208121.
<PropertyGroup>
<PublishProtocol>FileSystem</PublishProtocol>
<Platform>ARM64</Platform>
<RuntimeIdentifier>win-arm64</RuntimeIdentifier>
<PublishDir>bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\</PublishDir>
<SelfContained>true</SelfContained>
<PublishSingleFile>False</PublishSingleFile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ https://go.microsoft.com/fwlink/?LinkID=208121.
<PropertyGroup>
<PublishProtocol>FileSystem</PublishProtocol>
<Platform>x64</Platform>
<RuntimeIdentifier>win-x64</RuntimeIdentifier>
<PublishDir>bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\</PublishDir>
<SelfContained>true</SelfContained>
<PublishSingleFile>False</PublishSingleFile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ https://go.microsoft.com/fwlink/?LinkID=208121.
<PropertyGroup>
<PublishProtocol>FileSystem</PublishProtocol>
<Platform>x86</Platform>
<RuntimeIdentifier>win-x86</RuntimeIdentifier>
<PublishDir>bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\</PublishDir>
<SelfContained>true</SelfContained>
<PublishSingleFile>False</PublishSingleFile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@ namespace $safeprojectname$.SharedCode
{
internal class SampleNavigationParameters
{
public CancellationToken CancellationToken { get; private set; }
public string ModelPath => $modelPath$;
public HardwareAccelerator HardwareAccelerator => $modelHardwareAccelerator$;
$promptTemplate$
public SampleNavigationParameters(CancellationToken loadingCanceledToken)
{
CancellationToken = loadingCanceledToken;
}
}
}

0 comments on commit 4a2421d

Please sign in to comment.