Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion eng/CodeGeneration.targets
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
<Exec Command="npm ci --prefix $(_TspClientDir)" />
</Target>

<Target Name="GenerateCode" Condition="'$(TypeSpecInput)' != ''" DependsOnTargets="InstallTspClient">
<Target Name="BuildPlugin" Condition="'$(TypeSpecInput)' != '' AND '$(SkipBuildPlugin)' != 'true'">
<Exec Command="dotnet build $(MSBuildThisFileDirectory)packages/plugins/client/Client.Plugin/" />
</Target>

<Target Name="GenerateCode" Condition="'$(TypeSpecInput)' != ''" DependsOnTargets="InstallTspClient;BuildPlugin">
<Error Text="You used skipped sync but didn't have the TempTypeSpecFiles in your project directory. Please run 'dotnet build /t:GenerateCode /p:SaveInputs=true' without SkipSync first at least once" Condition="'$(SkipSync)' == 'true' AND !Exists('$(MSBuildProjectDirectory)/../TempTypeSpecFiles')" />
<Exec Condition="'$(SkipSync)' == 'true'" Command="$(_TypeSpecProjectGenerateCommand) $(_SaveInputs) $(_TypespecAdditionalOptions) $(_Trace)"/>
<Exec Condition="'$(SkipSync)' != 'true'" Command="$(_TypeSpecProjectSyncAndGenerateCommand) $(_SaveInputs) $(_LocalSpecRepo) $(_TypespecAdditionalOptions) $(_Trace)"/>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using Microsoft.TypeSpec.Generator.ClientModel;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Input;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Statements;
using System.Collections.Generic;
using System.Linq;
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;

namespace Azure.Generator.Visitors
Expand All @@ -17,8 +18,16 @@ namespace Azure.Generator.Visitors
/// </summary>
internal class SpecialHeadersVisitor : ScmLibraryVisitor
{
private readonly bool _includeClientRequestIdInRequest;
private const string ReturnClientRequestIdParameterName = "return-client-request-id";
private const string XMsClientRequestIdParameterName = "x-ms-client-request-id";
private readonly Dictionary<InputServiceMethod, (InputMethodParameter? ReturnClientRequestId, InputMethodParameter? XmsClientRequestId)> _serviceMethodParameterMap;

public SpecialHeadersVisitor(bool includeXmsClientRequestIdInRequest = false)
{
_includeClientRequestIdInRequest = includeXmsClientRequestIdInRequest;
_serviceMethodParameterMap = [];
}

protected override ScmMethodProviderCollection? Visit(
InputServiceMethod serviceMethod,
Expand All @@ -38,52 +47,76 @@ internal class SpecialHeadersVisitor : ScmLibraryVisitor
// Create a new method collection with the updated service method
methods = new ScmMethodProviderCollection(serviceMethod, client);

// Reset the rest client so that its methods are rebuilt.
// Store the parameters for the CreateRequest method
_serviceMethodParameterMap.TryAdd(serviceMethod, (returnClientRequestIdParameter, xMsClientRequestIdParameter));

// Reset the rest client so that its methods are rebuilt reflecting the new signatures
client.RestClient.Reset();
var createRequestMethod = client.RestClient.GetCreateRequestMethod(serviceMethod.Operation);
}

return methods;
}

protected override ScmMethodProvider? VisitMethod(ScmMethodProvider method)
{
if (method.ServiceMethod is null || !_serviceMethodParameterMap.TryGetValue(method.ServiceMethod, out var parameters))
{
return method;
}

var originalBodyStatements = createRequestMethod.BodyStatements!.ToList();
var (returnClientRequestIdParameter, xMsClientRequestIdParameter) = parameters;

// Exclude the last statement which is the return statement. We will add it back later.
var newStatements = new List<MethodBodyStatement>(originalBodyStatements[..^1]);
var originalBodyStatements = method.BodyStatements!.ToList();

// Find the request variable
VariableExpression? requestVariable = null;
foreach (var statement in newStatements)
// Exclude the last statement which is the return statement. We will add it back later.
var newStatements = new List<MethodBodyStatement>(originalBodyStatements[..^1]);

// Find the request variable
VariableExpression? requestVariable = null;
foreach (var statement in newStatements)
{
if (statement is ExpressionStatement
{
Expression: AssignmentExpression { Variable: DeclarationExpression declaration }
})
{
if (statement is ExpressionStatement
{
Expression: AssignmentExpression { Variable: DeclarationExpression declaration }
})
var variable = declaration.Variable;
if (variable.Type.Equals(variable.ToApi<HttpRequestApi>().Type))
{
var variable = declaration.Variable;
if (variable.Type.Equals(variable.ToApi<HttpRequestApi>().Type))
{
requestVariable = variable;
}
requestVariable = variable;
}
}
}

if (returnClientRequestIdParameter?.DefaultValue?.Value != null)
var request = requestVariable?.ToApi<HttpRequestApi>();
if (request != null && returnClientRequestIdParameter?.DefaultValue?.Value != null)
{
if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value))
{
if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value))
{
// Set the return-client-request-id header
newStatements.Add(requestVariable!.ToApi<HttpRequestApi>().SetHeaders(
[
Literal(returnClientRequestIdParameter.SerializedName),
Literal(value.ToString().ToLowerInvariant())
]));
}
// Set the return-client-request-id header
newStatements.Add(request.SetHeaders(
[
Literal(returnClientRequestIdParameter.SerializedName),
Literal(value.ToString().ToLowerInvariant())
]));
}
}

// Add the return statement back
newStatements.Add(originalBodyStatements[^1]);

createRequestMethod.Update(bodyStatements: newStatements);
if (request != null && _includeClientRequestIdInRequest && xMsClientRequestIdParameter != null)
{
// Set the x-ms-client-request-id header
newStatements.Add(request.SetHeaders(
[
Literal(xMsClientRequestIdParameter.SerializedName),
request.Property("ClientRequestId")
]));
}

return methods;
// Add the return statement back
newStatements.Add(originalBodyStatements[^1]);

method.Update(bodyStatements: newStatements);
return method;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public static Mock<AzureClientGenerator> LoadMockGenerator(
Func<IReadOnlyList<InputModelType>>? inputModels = null,
Func<IReadOnlyList<InputClient>>? clients = null,
Func<InputClient, ClientProvider?>? createClientCore = null,
Func<IReadOnlyList<ScmLibraryVisitor>>? visitors = null,
ClientResponseApi? clientResponseApi = null,
ClientPipelineApi? clientPipelineApi = null,
HttpMessageApi? httpMessageApi = null,
Expand Down Expand Up @@ -91,11 +92,24 @@ public static Mock<AzureClientGenerator> LoadMockGenerator(

var sourceInputModel = new Mock<SourceInputModel>(() => new SourceInputModel(null, null)) { CallBase = true };
mockPluginInstance.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object);
var configureMethod = typeof(CodeModelGenerator).GetMethod(
"Configure",
BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.InvokeMethod
);
configureMethod!.Invoke(mockPluginInstance.Object, null);

if (visitors != null)
{
var visitorsList = visitors.Invoke();
foreach (var visitor in visitorsList)
{
mockPluginInstance.Object.AddVisitor(visitor);
}
}
else
{
var configureMethod = typeof(CodeModelGenerator).GetMethod(
"Configure",
BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.InvokeMethod
);
configureMethod!.Invoke(mockPluginInstance.Object, null);
}

return mockPluginInstance;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,72 @@
using Azure.Generator.Tests.Common;
using Azure.Generator.Tests.TestHelpers;
using Azure.Generator.Visitors;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;

namespace Azure.Generator.Tests.Visitors
{
public class SpecialHeadersVisitorTests
{
[Test]
public void RemovesSpecialHeaderParametersFromServiceMethod()
[TestCase(true)]
[TestCase(false)]
public void RemovesSpecialHeaderParametersFromServiceMethod(bool addBackXMsClientRequestId)
{
var visitor = new TestSpecialHeadersVisitor();
var visitor = new TestSpecialHeadersVisitor(addBackXMsClientRequestId);
var parameters = CreateHttpParameters();
var methodParameters = CreateMethodParameters();
var responseModel = InputFactory.Model("foo");
var operation = InputFactory.Operation(
"foo",
parameters: parameters,
responses: [InputFactory.OperationResponse(bodytype: responseModel)]);
var serviceMethod = InputFactory.LongRunningServiceMethod(
"foo",
operation,
parameters: methodParameters,
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]);
MockHelpers.LoadMockGenerator(clients: () => [inputClient]);
var serviceMethods = new List<InputServiceMethod>();

// Verify initial parameters include special headers
Assert.AreEqual(3, serviceMethod.Parameters.Count);
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
// Create two operations and service methods
for (int i = 1; i <= 2; i++)
{
var operation = InputFactory.Operation(
$"operation{i}",
parameters: parameters,
responses: [InputFactory.OperationResponse(bodytype: responseModel)]);
var serviceMethod = InputFactory.LongRunningServiceMethod(
$"operation{i}",
operation,
parameters: methodParameters,
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
serviceMethods.Add(serviceMethod);
}

// Act - this would normally be called by the visitor framework, but we'll test the core logic
visitor.InvokeRemoveSpecialHeaders(serviceMethod);
var inputClient = InputFactory.Client("TestClient", methods: serviceMethods);

// Verify special headers are removed
Assert.AreEqual(1, serviceMethod.Parameters.Count);
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter"));
// Verify initial parameters include special headers for both methods
foreach (var serviceMethod in serviceMethods)
{
Assert.AreEqual(3, serviceMethod.Parameters.Count);
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
}

var generator = MockHelpers.LoadMockGenerator(
clients: () => [inputClient],
visitors: () => [visitor]);
var client = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().First();

// Verify special headers are removed from both methods
foreach (var serviceMethod in serviceMethods)
{
var methodCollection = client.GetMethodCollectionByOperation(serviceMethod.Operation);
visitor.InvokePreVisit(serviceMethod, client, methodCollection);
}

foreach (var serviceMethod in serviceMethods)
{
var restClientMethod = client.RestClient.Methods.First(m => m.Signature.Name == $"Create{serviceMethod.Name}Request");

visitor.InvokeVisit((restClientMethod as ScmMethodProvider)!);
Assert.AreEqual(1, serviceMethod.Parameters.Count);
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter"));

// Verify x-ms-client-request-id is added back in method body if specified
Assert.AreEqual(addBackXMsClientRequestId, restClientMethod.BodyStatements!.ToDisplayString().Contains("request.Headers.SetValue(\"x-ms-client-request-id\", request.ClientRequestId);"));
}
}

[Test]
Expand Down Expand Up @@ -76,13 +105,17 @@ public void DoesNotChangeParametersWhenNoSpecialHeaders()
operation,
parameters: methodParameters,
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]);
var generator = MockHelpers.LoadMockGenerator(clients: () => [inputClient]);
var client = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().First();
var methodCollection = client.GetMethodCollectionByOperation(operation);

// Verify initial state
Assert.AreEqual(1, serviceMethod.Parameters.Count);
var originalParameter = serviceMethod.Parameters[0];

// Act
visitor.InvokeRemoveSpecialHeaders(serviceMethod);
visitor.InvokePreVisit(serviceMethod, client, methodCollection);

// Verify no changes
Assert.AreEqual(1, serviceMethod.Parameters.Count);
Expand Down Expand Up @@ -136,19 +169,18 @@ private static List<InputParameter> CreateHttpParameters()

private class TestSpecialHeadersVisitor : SpecialHeadersVisitor
{
public void InvokeRemoveSpecialHeaders(InputServiceMethod serviceMethod)
public TestSpecialHeadersVisitor(bool addBackXMsClientRequestId = false)
: base(addBackXMsClientRequestId)
{
}
public void InvokePreVisit(InputServiceMethod serviceMethod, ClientProvider client, ScmMethodProviderCollection methods)
{
base.Visit(serviceMethod, client, methods);
}

public void InvokeVisit(ScmMethodProvider method)
{
// Simulate the core logic of removing special headers
var returnClientRequestIdParameter =
serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "return-client-request-id");
var xMsClientRequestIdParameter =
serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "x-ms-client-request-id");

if (returnClientRequestIdParameter != null || xMsClientRequestIdParameter != null)
{
serviceMethod.Update(parameters: [.. serviceMethod.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]);
serviceMethod.Operation.Update(parameters: [.. serviceMethod.Operation.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]);
}
base.VisitMethod(method);
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions eng/packages/plugins/client/Client.Plugin/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ trigger:
paths:
include:
- eng/packages/plugins/client/Client.Plugin
- eng/packages/http-client-csharp
pr:
branches:
include:
Expand All @@ -15,6 +16,7 @@ pr:
paths:
include:
- eng/packages/plugins/client/Client.Plugin
- eng/packages/http-client-csharp

variables:
- template: /eng/pipelines/templates/variables/image.yml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

<!-- Shared visitor source files -->
<PropertyGroup>
<VisitorsSharedSource>$(MSBuildThisFileDirectory)..\..\..\..\http-client-csharp\generator\Azure.Generator\src\Visitors</VisitorsSharedSource>
<AzureGeneratorSource>$(MSBuildThisFileDirectory)..\..\..\..\http-client-csharp\generator\Azure.Generator\src\</AzureGeneratorSource>
</PropertyGroup>

<ItemGroup>
<Compile Include="$(VisitorsSharedSource)\NamespaceVisitor.cs" LinkBase="Shared/Visitors" />
<Compile Include="$(VisitorsSharedSource)\SpecialHeadersVisitor.cs" LinkBase="Shared/Visitors" />
<Compile Include="$(VisitorsSharedSource)\ModelFactoryRenamerVisitor.cs" LinkBase="Shared/Visitors" />
<Compile Include="$(VisitorsSharedSource)\Extensions\ConfigurationExtensions.cs" LinkBase="Shared/Visitors/Extensions" />
<Compile Include="$(VisitorsSharedSource)\Utilities\TypeNameUtilities.cs" LinkBase="Shared/Visitors/Utilities" />
<Compile Include="$(AzureGeneratorSource)\Visitors\NamespaceVisitor.cs" LinkBase="Shared/" />
<Compile Include="$(AzureGeneratorSource)\Visitors\SpecialHeadersVisitor.cs" LinkBase="Shared/" />
<Compile Include="$(AzureGeneratorSource)\Visitors\ModelFactoryRenamerVisitor.cs" LinkBase="Shared/" />
<Compile Include="$(AzureGeneratorSource)\Extensions\ConfigurationExtensions.cs" LinkBase="Shared/" />
<Compile Include="$(AzureGeneratorSource)\Utilities\TypeNameUtilities.cs" LinkBase="Shared/" />
</ItemGroup>

<!-- Copy output to package dist path -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public override void Apply(CodeModelGenerator generator)

// Rest of the visitors can be added in any order.
generator.AddVisitor(new NamespaceVisitor());
generator.AddVisitor(new SpecialHeadersVisitor());
generator.AddVisitor(new SpecialHeadersVisitor(includeXmsClientRequestIdInRequest: true));
}
}
}
Loading
Loading