From c37a2f37f683e1c28cc09f67415bb26c805c6d6c Mon Sep 17 00:00:00 2001 From: jolov Date: Mon, 3 Nov 2025 15:39:31 -0800 Subject: [PATCH] Just plugin changes --- eng/CodeGeneration.targets | 6 +- .../src/Visitors/SpecialHeadersVisitor.cs | 101 ++++++++++------ .../test/TestHelpers/MockHelpers.cs | 24 +++- .../Visitors/SpecialHeadersVisitorTests.cs | 108 ++++++++++++------ .../plugins/client/Client.Plugin/ci.yml | 2 + .../Client.Plugin/src/Client.Plugin.csproj | 12 +- .../client/Client.Plugin/src/ClientPlugin.cs | 2 +- eng/service.proj | 16 ++- 8 files changed, 180 insertions(+), 91 deletions(-) diff --git a/eng/CodeGeneration.targets b/eng/CodeGeneration.targets index 78883666bd68..289f789d2c6a 100644 --- a/eng/CodeGeneration.targets +++ b/eng/CodeGeneration.targets @@ -21,7 +21,11 @@ - + + + + + diff --git a/eng/packages/http-client-csharp/generator/Azure.Generator/src/Visitors/SpecialHeadersVisitor.cs b/eng/packages/http-client-csharp/generator/Azure.Generator/src/Visitors/SpecialHeadersVisitor.cs index 28aaa25c523f..c467a39a7fc3 100644 --- a/eng/packages/http-client-csharp/generator/Azure.Generator/src/Visitors/SpecialHeadersVisitor.cs +++ b/eng/packages/http-client-csharp/generator/Azure.Generator/src/Visitors/SpecialHeadersVisitor.cs @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System.Collections.Generic; -using System.Linq; using Microsoft.TypeSpec.Generator.ClientModel; using Microsoft.TypeSpec.Generator.ClientModel.Providers; using Microsoft.TypeSpec.Generator.Expressions; using Microsoft.TypeSpec.Generator.Input; +using Microsoft.TypeSpec.Generator.Providers; using Microsoft.TypeSpec.Generator.Statements; +using System.Collections.Generic; +using System.Linq; using static Microsoft.TypeSpec.Generator.Snippets.Snippet; namespace Azure.Generator.Visitors @@ -17,8 +18,16 @@ namespace Azure.Generator.Visitors /// internal class SpecialHeadersVisitor : ScmLibraryVisitor { + private readonly bool _includeClientRequestIdInRequest; private const string ReturnClientRequestIdParameterName = "return-client-request-id"; private const string XMsClientRequestIdParameterName = "x-ms-client-request-id"; + private readonly Dictionary _serviceMethodParameterMap; + + public SpecialHeadersVisitor(bool includeXmsClientRequestIdInRequest = false) + { + _includeClientRequestIdInRequest = includeXmsClientRequestIdInRequest; + _serviceMethodParameterMap = []; + } protected override ScmMethodProviderCollection? Visit( InputServiceMethod serviceMethod, @@ -38,52 +47,76 @@ internal class SpecialHeadersVisitor : ScmLibraryVisitor // Create a new method collection with the updated service method methods = new ScmMethodProviderCollection(serviceMethod, client); - // Reset the rest client so that its methods are rebuilt. + // Store the parameters for the CreateRequest method + _serviceMethodParameterMap.TryAdd(serviceMethod, (returnClientRequestIdParameter, xMsClientRequestIdParameter)); + + // Reset the rest client so that its methods are rebuilt reflecting the new signatures client.RestClient.Reset(); - var createRequestMethod = client.RestClient.GetCreateRequestMethod(serviceMethod.Operation); + } + + return methods; + } + + protected override ScmMethodProvider? VisitMethod(ScmMethodProvider method) + { + if (method.ServiceMethod is null || !_serviceMethodParameterMap.TryGetValue(method.ServiceMethod, out var parameters)) + { + return method; + } - var originalBodyStatements = createRequestMethod.BodyStatements!.ToList(); + var (returnClientRequestIdParameter, xMsClientRequestIdParameter) = parameters; - // Exclude the last statement which is the return statement. We will add it back later. - var newStatements = new List(originalBodyStatements[..^1]); + var originalBodyStatements = method.BodyStatements!.ToList(); - // Find the request variable - VariableExpression? requestVariable = null; - foreach (var statement in newStatements) + // Exclude the last statement which is the return statement. We will add it back later. + var newStatements = new List(originalBodyStatements[..^1]); + + // Find the request variable + VariableExpression? requestVariable = null; + foreach (var statement in newStatements) + { + if (statement is ExpressionStatement + { + Expression: AssignmentExpression { Variable: DeclarationExpression declaration } + }) { - if (statement is ExpressionStatement - { - Expression: AssignmentExpression { Variable: DeclarationExpression declaration } - }) + var variable = declaration.Variable; + if (variable.Type.Equals(variable.ToApi().Type)) { - var variable = declaration.Variable; - if (variable.Type.Equals(variable.ToApi().Type)) - { - requestVariable = variable; - } + requestVariable = variable; } } + } - if (returnClientRequestIdParameter?.DefaultValue?.Value != null) + var request = requestVariable?.ToApi(); + if (request != null && returnClientRequestIdParameter?.DefaultValue?.Value != null) + { + if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value)) { - if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value)) - { - // Set the return-client-request-id header - newStatements.Add(requestVariable!.ToApi().SetHeaders( - [ - Literal(returnClientRequestIdParameter.SerializedName), - Literal(value.ToString().ToLowerInvariant()) - ])); - } + // Set the return-client-request-id header + newStatements.Add(request.SetHeaders( + [ + Literal(returnClientRequestIdParameter.SerializedName), + Literal(value.ToString().ToLowerInvariant()) + ])); } + } - // Add the return statement back - newStatements.Add(originalBodyStatements[^1]); - - createRequestMethod.Update(bodyStatements: newStatements); + if (request != null && _includeClientRequestIdInRequest && xMsClientRequestIdParameter != null) + { + // Set the x-ms-client-request-id header + newStatements.Add(request.SetHeaders( + [ + Literal(xMsClientRequestIdParameter.SerializedName), + request.Property("ClientRequestId") + ])); } - return methods; + // Add the return statement back + newStatements.Add(originalBodyStatements[^1]); + + method.Update(bodyStatements: newStatements); + return method; } } } \ No newline at end of file diff --git a/eng/packages/http-client-csharp/generator/Azure.Generator/test/TestHelpers/MockHelpers.cs b/eng/packages/http-client-csharp/generator/Azure.Generator/test/TestHelpers/MockHelpers.cs index 61157a48b111..acfb8a02668f 100644 --- a/eng/packages/http-client-csharp/generator/Azure.Generator/test/TestHelpers/MockHelpers.cs +++ b/eng/packages/http-client-csharp/generator/Azure.Generator/test/TestHelpers/MockHelpers.cs @@ -33,6 +33,7 @@ public static Mock LoadMockGenerator( Func>? inputModels = null, Func>? clients = null, Func? createClientCore = null, + Func>? visitors = null, ClientResponseApi? clientResponseApi = null, ClientPipelineApi? clientPipelineApi = null, HttpMessageApi? httpMessageApi = null, @@ -91,11 +92,24 @@ public static Mock LoadMockGenerator( var sourceInputModel = new Mock(() => new SourceInputModel(null, null)) { CallBase = true }; mockPluginInstance.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); - var configureMethod = typeof(CodeModelGenerator).GetMethod( - "Configure", - BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.InvokeMethod - ); - configureMethod!.Invoke(mockPluginInstance.Object, null); + + if (visitors != null) + { + var visitorsList = visitors.Invoke(); + foreach (var visitor in visitorsList) + { + mockPluginInstance.Object.AddVisitor(visitor); + } + } + else + { + var configureMethod = typeof(CodeModelGenerator).GetMethod( + "Configure", + BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.InvokeMethod + ); + configureMethod!.Invoke(mockPluginInstance.Object, null); + } + return mockPluginInstance; } diff --git a/eng/packages/http-client-csharp/generator/Azure.Generator/test/Visitors/SpecialHeadersVisitorTests.cs b/eng/packages/http-client-csharp/generator/Azure.Generator/test/Visitors/SpecialHeadersVisitorTests.cs index 38097603a828..7a8e85e67897 100644 --- a/eng/packages/http-client-csharp/generator/Azure.Generator/test/Visitors/SpecialHeadersVisitorTests.cs +++ b/eng/packages/http-client-csharp/generator/Azure.Generator/test/Visitors/SpecialHeadersVisitorTests.cs @@ -8,43 +8,72 @@ using Azure.Generator.Tests.Common; using Azure.Generator.Tests.TestHelpers; using Azure.Generator.Visitors; +using Microsoft.TypeSpec.Generator.ClientModel.Providers; namespace Azure.Generator.Tests.Visitors { public class SpecialHeadersVisitorTests { - [Test] - public void RemovesSpecialHeaderParametersFromServiceMethod() + [TestCase(true)] + [TestCase(false)] + public void RemovesSpecialHeaderParametersFromServiceMethod(bool addBackXMsClientRequestId) { - var visitor = new TestSpecialHeadersVisitor(); + var visitor = new TestSpecialHeadersVisitor(addBackXMsClientRequestId); var parameters = CreateHttpParameters(); var methodParameters = CreateMethodParameters(); var responseModel = InputFactory.Model("foo"); - var operation = InputFactory.Operation( - "foo", - parameters: parameters, - responses: [InputFactory.OperationResponse(bodytype: responseModel)]); - var serviceMethod = InputFactory.LongRunningServiceMethod( - "foo", - operation, - parameters: methodParameters, - response: InputFactory.ServiceMethodResponse(responseModel, ["result"])); - var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]); - MockHelpers.LoadMockGenerator(clients: () => [inputClient]); + var serviceMethods = new List(); - // Verify initial parameters include special headers - Assert.AreEqual(3, serviceMethod.Parameters.Count); - Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); - Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); + // Create two operations and service methods + for (int i = 1; i <= 2; i++) + { + var operation = InputFactory.Operation( + $"operation{i}", + parameters: parameters, + responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var serviceMethod = InputFactory.LongRunningServiceMethod( + $"operation{i}", + operation, + parameters: methodParameters, + response: InputFactory.ServiceMethodResponse(responseModel, ["result"])); + serviceMethods.Add(serviceMethod); + } - // Act - this would normally be called by the visitor framework, but we'll test the core logic - visitor.InvokeRemoveSpecialHeaders(serviceMethod); + var inputClient = InputFactory.Client("TestClient", methods: serviceMethods); - // Verify special headers are removed - Assert.AreEqual(1, serviceMethod.Parameters.Count); - Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); - Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); - Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter")); + // Verify initial parameters include special headers for both methods + foreach (var serviceMethod in serviceMethods) + { + Assert.AreEqual(3, serviceMethod.Parameters.Count); + Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); + Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); + } + + var generator = MockHelpers.LoadMockGenerator( + clients: () => [inputClient], + visitors: () => [visitor]); + var client = generator.Object.OutputLibrary.TypeProviders.OfType().First(); + + // Verify special headers are removed from both methods + foreach (var serviceMethod in serviceMethods) + { + var methodCollection = client.GetMethodCollectionByOperation(serviceMethod.Operation); + visitor.InvokePreVisit(serviceMethod, client, methodCollection); + } + + foreach (var serviceMethod in serviceMethods) + { + var restClientMethod = client.RestClient.Methods.First(m => m.Signature.Name == $"Create{serviceMethod.Name}Request"); + + visitor.InvokeVisit((restClientMethod as ScmMethodProvider)!); + Assert.AreEqual(1, serviceMethod.Parameters.Count); + Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); + Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); + Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter")); + + // Verify x-ms-client-request-id is added back in method body if specified + Assert.AreEqual(addBackXMsClientRequestId, restClientMethod.BodyStatements!.ToDisplayString().Contains("request.Headers.SetValue(\"x-ms-client-request-id\", request.ClientRequestId);")); + } } [Test] @@ -76,13 +105,17 @@ public void DoesNotChangeParametersWhenNoSpecialHeaders() operation, parameters: methodParameters, response: InputFactory.ServiceMethodResponse(responseModel, ["result"])); + var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]); + var generator = MockHelpers.LoadMockGenerator(clients: () => [inputClient]); + var client = generator.Object.OutputLibrary.TypeProviders.OfType().First(); + var methodCollection = client.GetMethodCollectionByOperation(operation); // Verify initial state Assert.AreEqual(1, serviceMethod.Parameters.Count); var originalParameter = serviceMethod.Parameters[0]; // Act - visitor.InvokeRemoveSpecialHeaders(serviceMethod); + visitor.InvokePreVisit(serviceMethod, client, methodCollection); // Verify no changes Assert.AreEqual(1, serviceMethod.Parameters.Count); @@ -136,19 +169,18 @@ private static List CreateHttpParameters() private class TestSpecialHeadersVisitor : SpecialHeadersVisitor { - public void InvokeRemoveSpecialHeaders(InputServiceMethod serviceMethod) + public TestSpecialHeadersVisitor(bool addBackXMsClientRequestId = false) + : base(addBackXMsClientRequestId) + { + } + public void InvokePreVisit(InputServiceMethod serviceMethod, ClientProvider client, ScmMethodProviderCollection methods) + { + base.Visit(serviceMethod, client, methods); + } + + public void InvokeVisit(ScmMethodProvider method) { - // Simulate the core logic of removing special headers - var returnClientRequestIdParameter = - serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "return-client-request-id"); - var xMsClientRequestIdParameter = - serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "x-ms-client-request-id"); - - if (returnClientRequestIdParameter != null || xMsClientRequestIdParameter != null) - { - serviceMethod.Update(parameters: [.. serviceMethod.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]); - serviceMethod.Operation.Update(parameters: [.. serviceMethod.Operation.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]); - } + base.VisitMethod(method); } } } diff --git a/eng/packages/plugins/client/Client.Plugin/ci.yml b/eng/packages/plugins/client/Client.Plugin/ci.yml index 8485fa4ad1ae..e26f690583a3 100644 --- a/eng/packages/plugins/client/Client.Plugin/ci.yml +++ b/eng/packages/plugins/client/Client.Plugin/ci.yml @@ -5,6 +5,7 @@ trigger: paths: include: - eng/packages/plugins/client/Client.Plugin + - eng/packages/http-client-csharp pr: branches: include: @@ -15,6 +16,7 @@ pr: paths: include: - eng/packages/plugins/client/Client.Plugin + - eng/packages/http-client-csharp variables: - template: /eng/pipelines/templates/variables/image.yml diff --git a/eng/packages/plugins/client/Client.Plugin/src/Client.Plugin.csproj b/eng/packages/plugins/client/Client.Plugin/src/Client.Plugin.csproj index 3076ce2177e7..167e972f5755 100644 --- a/eng/packages/plugins/client/Client.Plugin/src/Client.Plugin.csproj +++ b/eng/packages/plugins/client/Client.Plugin/src/Client.Plugin.csproj @@ -14,15 +14,15 @@ - $(MSBuildThisFileDirectory)..\..\..\..\http-client-csharp\generator\Azure.Generator\src\Visitors + $(MSBuildThisFileDirectory)..\..\..\..\http-client-csharp\generator\Azure.Generator\src\ - - - - - + + + + + diff --git a/eng/packages/plugins/client/Client.Plugin/src/ClientPlugin.cs b/eng/packages/plugins/client/Client.Plugin/src/ClientPlugin.cs index e234953472b8..f0b29c56205d 100644 --- a/eng/packages/plugins/client/Client.Plugin/src/ClientPlugin.cs +++ b/eng/packages/plugins/client/Client.Plugin/src/ClientPlugin.cs @@ -19,7 +19,7 @@ public override void Apply(CodeModelGenerator generator) // Rest of the visitors can be added in any order. generator.AddVisitor(new NamespaceVisitor()); - generator.AddVisitor(new SpecialHeadersVisitor()); + generator.AddVisitor(new SpecialHeadersVisitor(includeXmsClientRequestIdInRequest: true)); } } } diff --git a/eng/service.proj b/eng/service.proj index 801e50b1191b..9bb44b519c2d 100644 --- a/eng/service.proj +++ b/eng/service.proj @@ -107,7 +107,11 @@ SkipNonexistentTargets="true" /> - + + + + + - + - + -