Skip to content

Commit c37a2f3

Browse files
committed
Just plugin changes
1 parent 3dad518 commit c37a2f3

File tree

8 files changed

+180
-91
lines changed

8 files changed

+180
-91
lines changed

eng/CodeGeneration.targets

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
<Exec Command="npm ci --prefix $(_TspClientDir)" />
2222
</Target>
2323

24-
<Target Name="GenerateCode" Condition="'$(TypeSpecInput)' != ''" DependsOnTargets="InstallTspClient">
24+
<Target Name="BuildPlugin" Condition="'$(TypeSpecInput)' != '' AND '$(SkipBuildPlugin)' != 'true'">
25+
<Exec Command="dotnet build $(MSBuildThisFileDirectory)packages/plugins/client/Client.Plugin/" />
26+
</Target>
27+
28+
<Target Name="GenerateCode" Condition="'$(TypeSpecInput)' != ''" DependsOnTargets="InstallTspClient;BuildPlugin">
2529
<Error Text="You used skipped sync but didn't have the TempTypeSpecFiles in your project directory. Please run 'dotnet build /t:GenerateCode /p:SaveInputs=true' without SkipSync first at least once" Condition="'$(SkipSync)' == 'true' AND !Exists('$(MSBuildProjectDirectory)/../TempTypeSpecFiles')" />
2630
<Exec Condition="'$(SkipSync)' == 'true'" Command="$(_TypeSpecProjectGenerateCommand) $(_SaveInputs) $(_TypespecAdditionalOptions) $(_Trace)"/>
2731
<Exec Condition="'$(SkipSync)' != 'true'" Command="$(_TypeSpecProjectSyncAndGenerateCommand) $(_SaveInputs) $(_LocalSpecRepo) $(_TypespecAdditionalOptions) $(_Trace)"/>
Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
using System.Collections.Generic;
5-
using System.Linq;
64
using Microsoft.TypeSpec.Generator.ClientModel;
75
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
86
using Microsoft.TypeSpec.Generator.Expressions;
97
using Microsoft.TypeSpec.Generator.Input;
8+
using Microsoft.TypeSpec.Generator.Providers;
109
using Microsoft.TypeSpec.Generator.Statements;
10+
using System.Collections.Generic;
11+
using System.Linq;
1112
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;
1213

1314
namespace Azure.Generator.Visitors
@@ -17,8 +18,16 @@ namespace Azure.Generator.Visitors
1718
/// </summary>
1819
internal class SpecialHeadersVisitor : ScmLibraryVisitor
1920
{
21+
private readonly bool _includeClientRequestIdInRequest;
2022
private const string ReturnClientRequestIdParameterName = "return-client-request-id";
2123
private const string XMsClientRequestIdParameterName = "x-ms-client-request-id";
24+
private readonly Dictionary<InputServiceMethod, (InputMethodParameter? ReturnClientRequestId, InputMethodParameter? XmsClientRequestId)> _serviceMethodParameterMap;
25+
26+
public SpecialHeadersVisitor(bool includeXmsClientRequestIdInRequest = false)
27+
{
28+
_includeClientRequestIdInRequest = includeXmsClientRequestIdInRequest;
29+
_serviceMethodParameterMap = [];
30+
}
2231

2332
protected override ScmMethodProviderCollection? Visit(
2433
InputServiceMethod serviceMethod,
@@ -38,52 +47,76 @@ internal class SpecialHeadersVisitor : ScmLibraryVisitor
3847
// Create a new method collection with the updated service method
3948
methods = new ScmMethodProviderCollection(serviceMethod, client);
4049

41-
// Reset the rest client so that its methods are rebuilt.
50+
// Store the parameters for the CreateRequest method
51+
_serviceMethodParameterMap.TryAdd(serviceMethod, (returnClientRequestIdParameter, xMsClientRequestIdParameter));
52+
53+
// Reset the rest client so that its methods are rebuilt reflecting the new signatures
4254
client.RestClient.Reset();
43-
var createRequestMethod = client.RestClient.GetCreateRequestMethod(serviceMethod.Operation);
55+
}
56+
57+
return methods;
58+
}
59+
60+
protected override ScmMethodProvider? VisitMethod(ScmMethodProvider method)
61+
{
62+
if (method.ServiceMethod is null || !_serviceMethodParameterMap.TryGetValue(method.ServiceMethod, out var parameters))
63+
{
64+
return method;
65+
}
4466

45-
var originalBodyStatements = createRequestMethod.BodyStatements!.ToList();
67+
var (returnClientRequestIdParameter, xMsClientRequestIdParameter) = parameters;
4668

47-
// Exclude the last statement which is the return statement. We will add it back later.
48-
var newStatements = new List<MethodBodyStatement>(originalBodyStatements[..^1]);
69+
var originalBodyStatements = method.BodyStatements!.ToList();
4970

50-
// Find the request variable
51-
VariableExpression? requestVariable = null;
52-
foreach (var statement in newStatements)
71+
// Exclude the last statement which is the return statement. We will add it back later.
72+
var newStatements = new List<MethodBodyStatement>(originalBodyStatements[..^1]);
73+
74+
// Find the request variable
75+
VariableExpression? requestVariable = null;
76+
foreach (var statement in newStatements)
77+
{
78+
if (statement is ExpressionStatement
79+
{
80+
Expression: AssignmentExpression { Variable: DeclarationExpression declaration }
81+
})
5382
{
54-
if (statement is ExpressionStatement
55-
{
56-
Expression: AssignmentExpression { Variable: DeclarationExpression declaration }
57-
})
83+
var variable = declaration.Variable;
84+
if (variable.Type.Equals(variable.ToApi<HttpRequestApi>().Type))
5885
{
59-
var variable = declaration.Variable;
60-
if (variable.Type.Equals(variable.ToApi<HttpRequestApi>().Type))
61-
{
62-
requestVariable = variable;
63-
}
86+
requestVariable = variable;
6487
}
6588
}
89+
}
6690

67-
if (returnClientRequestIdParameter?.DefaultValue?.Value != null)
91+
var request = requestVariable?.ToApi<HttpRequestApi>();
92+
if (request != null && returnClientRequestIdParameter?.DefaultValue?.Value != null)
93+
{
94+
if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value))
6895
{
69-
if (bool.TryParse(returnClientRequestIdParameter.DefaultValue.Value.ToString(), out bool value))
70-
{
71-
// Set the return-client-request-id header
72-
newStatements.Add(requestVariable!.ToApi<HttpRequestApi>().SetHeaders(
73-
[
74-
Literal(returnClientRequestIdParameter.SerializedName),
75-
Literal(value.ToString().ToLowerInvariant())
76-
]));
77-
}
96+
// Set the return-client-request-id header
97+
newStatements.Add(request.SetHeaders(
98+
[
99+
Literal(returnClientRequestIdParameter.SerializedName),
100+
Literal(value.ToString().ToLowerInvariant())
101+
]));
78102
}
103+
}
79104

80-
// Add the return statement back
81-
newStatements.Add(originalBodyStatements[^1]);
82-
83-
createRequestMethod.Update(bodyStatements: newStatements);
105+
if (request != null && _includeClientRequestIdInRequest && xMsClientRequestIdParameter != null)
106+
{
107+
// Set the x-ms-client-request-id header
108+
newStatements.Add(request.SetHeaders(
109+
[
110+
Literal(xMsClientRequestIdParameter.SerializedName),
111+
request.Property("ClientRequestId")
112+
]));
84113
}
85114

86-
return methods;
115+
// Add the return statement back
116+
newStatements.Add(originalBodyStatements[^1]);
117+
118+
method.Update(bodyStatements: newStatements);
119+
return method;
87120
}
88121
}
89122
}

eng/packages/http-client-csharp/generator/Azure.Generator/test/TestHelpers/MockHelpers.cs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public static Mock<AzureClientGenerator> LoadMockGenerator(
3333
Func<IReadOnlyList<InputModelType>>? inputModels = null,
3434
Func<IReadOnlyList<InputClient>>? clients = null,
3535
Func<InputClient, ClientProvider?>? createClientCore = null,
36+
Func<IReadOnlyList<ScmLibraryVisitor>>? visitors = null,
3637
ClientResponseApi? clientResponseApi = null,
3738
ClientPipelineApi? clientPipelineApi = null,
3839
HttpMessageApi? httpMessageApi = null,
@@ -91,11 +92,24 @@ public static Mock<AzureClientGenerator> LoadMockGenerator(
9192

9293
var sourceInputModel = new Mock<SourceInputModel>(() => new SourceInputModel(null, null)) { CallBase = true };
9394
mockPluginInstance.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object);
94-
var configureMethod = typeof(CodeModelGenerator).GetMethod(
95-
"Configure",
96-
BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.InvokeMethod
97-
);
98-
configureMethod!.Invoke(mockPluginInstance.Object, null);
95+
96+
if (visitors != null)
97+
{
98+
var visitorsList = visitors.Invoke();
99+
foreach (var visitor in visitorsList)
100+
{
101+
mockPluginInstance.Object.AddVisitor(visitor);
102+
}
103+
}
104+
else
105+
{
106+
var configureMethod = typeof(CodeModelGenerator).GetMethod(
107+
"Configure",
108+
BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.InvokeMethod
109+
);
110+
configureMethod!.Invoke(mockPluginInstance.Object, null);
111+
}
112+
99113
return mockPluginInstance;
100114
}
101115

eng/packages/http-client-csharp/generator/Azure.Generator/test/Visitors/SpecialHeadersVisitorTests.cs

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,72 @@
88
using Azure.Generator.Tests.Common;
99
using Azure.Generator.Tests.TestHelpers;
1010
using Azure.Generator.Visitors;
11+
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
1112

1213
namespace Azure.Generator.Tests.Visitors
1314
{
1415
public class SpecialHeadersVisitorTests
1516
{
16-
[Test]
17-
public void RemovesSpecialHeaderParametersFromServiceMethod()
17+
[TestCase(true)]
18+
[TestCase(false)]
19+
public void RemovesSpecialHeaderParametersFromServiceMethod(bool addBackXMsClientRequestId)
1820
{
19-
var visitor = new TestSpecialHeadersVisitor();
21+
var visitor = new TestSpecialHeadersVisitor(addBackXMsClientRequestId);
2022
var parameters = CreateHttpParameters();
2123
var methodParameters = CreateMethodParameters();
2224
var responseModel = InputFactory.Model("foo");
23-
var operation = InputFactory.Operation(
24-
"foo",
25-
parameters: parameters,
26-
responses: [InputFactory.OperationResponse(bodytype: responseModel)]);
27-
var serviceMethod = InputFactory.LongRunningServiceMethod(
28-
"foo",
29-
operation,
30-
parameters: methodParameters,
31-
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
32-
var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]);
33-
MockHelpers.LoadMockGenerator(clients: () => [inputClient]);
25+
var serviceMethods = new List<InputServiceMethod>();
3426

35-
// Verify initial parameters include special headers
36-
Assert.AreEqual(3, serviceMethod.Parameters.Count);
37-
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
38-
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
27+
// Create two operations and service methods
28+
for (int i = 1; i <= 2; i++)
29+
{
30+
var operation = InputFactory.Operation(
31+
$"operation{i}",
32+
parameters: parameters,
33+
responses: [InputFactory.OperationResponse(bodytype: responseModel)]);
34+
var serviceMethod = InputFactory.LongRunningServiceMethod(
35+
$"operation{i}",
36+
operation,
37+
parameters: methodParameters,
38+
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
39+
serviceMethods.Add(serviceMethod);
40+
}
3941

40-
// Act - this would normally be called by the visitor framework, but we'll test the core logic
41-
visitor.InvokeRemoveSpecialHeaders(serviceMethod);
42+
var inputClient = InputFactory.Client("TestClient", methods: serviceMethods);
4243

43-
// Verify special headers are removed
44-
Assert.AreEqual(1, serviceMethod.Parameters.Count);
45-
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
46-
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
47-
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter"));
44+
// Verify initial parameters include special headers for both methods
45+
foreach (var serviceMethod in serviceMethods)
46+
{
47+
Assert.AreEqual(3, serviceMethod.Parameters.Count);
48+
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
49+
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
50+
}
51+
52+
var generator = MockHelpers.LoadMockGenerator(
53+
clients: () => [inputClient],
54+
visitors: () => [visitor]);
55+
var client = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().First();
56+
57+
// Verify special headers are removed from both methods
58+
foreach (var serviceMethod in serviceMethods)
59+
{
60+
var methodCollection = client.GetMethodCollectionByOperation(serviceMethod.Operation);
61+
visitor.InvokePreVisit(serviceMethod, client, methodCollection);
62+
}
63+
64+
foreach (var serviceMethod in serviceMethods)
65+
{
66+
var restClientMethod = client.RestClient.Methods.First(m => m.Signature.Name == $"Create{serviceMethod.Name}Request");
67+
68+
visitor.InvokeVisit((restClientMethod as ScmMethodProvider)!);
69+
Assert.AreEqual(1, serviceMethod.Parameters.Count);
70+
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id"));
71+
Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id"));
72+
Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter"));
73+
74+
// Verify x-ms-client-request-id is added back in method body if specified
75+
Assert.AreEqual(addBackXMsClientRequestId, restClientMethod.BodyStatements!.ToDisplayString().Contains("request.Headers.SetValue(\"x-ms-client-request-id\", request.ClientRequestId);"));
76+
}
4877
}
4978

5079
[Test]
@@ -76,13 +105,17 @@ public void DoesNotChangeParametersWhenNoSpecialHeaders()
76105
operation,
77106
parameters: methodParameters,
78107
response: InputFactory.ServiceMethodResponse(responseModel, ["result"]));
108+
var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]);
109+
var generator = MockHelpers.LoadMockGenerator(clients: () => [inputClient]);
110+
var client = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().First();
111+
var methodCollection = client.GetMethodCollectionByOperation(operation);
79112

80113
// Verify initial state
81114
Assert.AreEqual(1, serviceMethod.Parameters.Count);
82115
var originalParameter = serviceMethod.Parameters[0];
83116

84117
// Act
85-
visitor.InvokeRemoveSpecialHeaders(serviceMethod);
118+
visitor.InvokePreVisit(serviceMethod, client, methodCollection);
86119

87120
// Verify no changes
88121
Assert.AreEqual(1, serviceMethod.Parameters.Count);
@@ -136,19 +169,18 @@ private static List<InputParameter> CreateHttpParameters()
136169

137170
private class TestSpecialHeadersVisitor : SpecialHeadersVisitor
138171
{
139-
public void InvokeRemoveSpecialHeaders(InputServiceMethod serviceMethod)
172+
public TestSpecialHeadersVisitor(bool addBackXMsClientRequestId = false)
173+
: base(addBackXMsClientRequestId)
174+
{
175+
}
176+
public void InvokePreVisit(InputServiceMethod serviceMethod, ClientProvider client, ScmMethodProviderCollection methods)
177+
{
178+
base.Visit(serviceMethod, client, methods);
179+
}
180+
181+
public void InvokeVisit(ScmMethodProvider method)
140182
{
141-
// Simulate the core logic of removing special headers
142-
var returnClientRequestIdParameter =
143-
serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "return-client-request-id");
144-
var xMsClientRequestIdParameter =
145-
serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "x-ms-client-request-id");
146-
147-
if (returnClientRequestIdParameter != null || xMsClientRequestIdParameter != null)
148-
{
149-
serviceMethod.Update(parameters: [.. serviceMethod.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]);
150-
serviceMethod.Operation.Update(parameters: [.. serviceMethod.Operation.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]);
151-
}
183+
base.VisitMethod(method);
152184
}
153185
}
154186
}

eng/packages/plugins/client/Client.Plugin/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ trigger:
55
paths:
66
include:
77
- eng/packages/plugins/client/Client.Plugin
8+
- eng/packages/http-client-csharp
89
pr:
910
branches:
1011
include:
@@ -15,6 +16,7 @@ pr:
1516
paths:
1617
include:
1718
- eng/packages/plugins/client/Client.Plugin
19+
- eng/packages/http-client-csharp
1820

1921
variables:
2022
- template: /eng/pipelines/templates/variables/image.yml

eng/packages/plugins/client/Client.Plugin/src/Client.Plugin.csproj

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
<!-- Shared visitor source files -->
1616
<PropertyGroup>
17-
<VisitorsSharedSource>$(MSBuildThisFileDirectory)..\..\..\..\http-client-csharp\generator\Azure.Generator\src\Visitors</VisitorsSharedSource>
17+
<AzureGeneratorSource>$(MSBuildThisFileDirectory)..\..\..\..\http-client-csharp\generator\Azure.Generator\src\</AzureGeneratorSource>
1818
</PropertyGroup>
1919

2020
<ItemGroup>
21-
<Compile Include="$(VisitorsSharedSource)\NamespaceVisitor.cs" LinkBase="Shared/Visitors" />
22-
<Compile Include="$(VisitorsSharedSource)\SpecialHeadersVisitor.cs" LinkBase="Shared/Visitors" />
23-
<Compile Include="$(VisitorsSharedSource)\ModelFactoryRenamerVisitor.cs" LinkBase="Shared/Visitors" />
24-
<Compile Include="$(VisitorsSharedSource)\Extensions\ConfigurationExtensions.cs" LinkBase="Shared/Visitors/Extensions" />
25-
<Compile Include="$(VisitorsSharedSource)\Utilities\TypeNameUtilities.cs" LinkBase="Shared/Visitors/Utilities" />
21+
<Compile Include="$(AzureGeneratorSource)\Visitors\NamespaceVisitor.cs" LinkBase="Shared/" />
22+
<Compile Include="$(AzureGeneratorSource)\Visitors\SpecialHeadersVisitor.cs" LinkBase="Shared/" />
23+
<Compile Include="$(AzureGeneratorSource)\Visitors\ModelFactoryRenamerVisitor.cs" LinkBase="Shared/" />
24+
<Compile Include="$(AzureGeneratorSource)\Extensions\ConfigurationExtensions.cs" LinkBase="Shared/" />
25+
<Compile Include="$(AzureGeneratorSource)\Utilities\TypeNameUtilities.cs" LinkBase="Shared/" />
2626
</ItemGroup>
2727

2828
<!-- Copy output to package dist path -->

eng/packages/plugins/client/Client.Plugin/src/ClientPlugin.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public override void Apply(CodeModelGenerator generator)
1919

2020
// Rest of the visitors can be added in any order.
2121
generator.AddVisitor(new NamespaceVisitor());
22-
generator.AddVisitor(new SpecialHeadersVisitor());
22+
generator.AddVisitor(new SpecialHeadersVisitor(includeXmsClientRequestIdInRequest: true));
2323
}
2424
}
2525
}

0 commit comments

Comments
 (0)