|
8 | 8 | using Azure.Generator.Tests.Common; |
9 | 9 | using Azure.Generator.Tests.TestHelpers; |
10 | 10 | using Azure.Generator.Visitors; |
| 11 | +using Microsoft.TypeSpec.Generator.ClientModel.Providers; |
11 | 12 |
|
12 | 13 | namespace Azure.Generator.Tests.Visitors |
13 | 14 | { |
14 | 15 | public class SpecialHeadersVisitorTests |
15 | 16 | { |
16 | | - [Test] |
17 | | - public void RemovesSpecialHeaderParametersFromServiceMethod() |
| 17 | + [TestCase(true)] |
| 18 | + [TestCase(false)] |
| 19 | + public void RemovesSpecialHeaderParametersFromServiceMethod(bool addBackXMsClientRequestId) |
18 | 20 | { |
19 | | - var visitor = new TestSpecialHeadersVisitor(); |
| 21 | + var visitor = new TestSpecialHeadersVisitor(addBackXMsClientRequestId); |
20 | 22 | var parameters = CreateHttpParameters(); |
21 | 23 | var methodParameters = CreateMethodParameters(); |
22 | 24 | var responseModel = InputFactory.Model("foo"); |
23 | | - var operation = InputFactory.Operation( |
24 | | - "foo", |
25 | | - parameters: parameters, |
26 | | - responses: [InputFactory.OperationResponse(bodytype: responseModel)]); |
27 | | - var serviceMethod = InputFactory.LongRunningServiceMethod( |
28 | | - "foo", |
29 | | - operation, |
30 | | - parameters: methodParameters, |
31 | | - response: InputFactory.ServiceMethodResponse(responseModel, ["result"])); |
32 | | - var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]); |
33 | | - MockHelpers.LoadMockGenerator(clients: () => [inputClient]); |
| 25 | + var serviceMethods = new List<InputServiceMethod>(); |
34 | 26 |
|
35 | | - // Verify initial parameters include special headers |
36 | | - Assert.AreEqual(3, serviceMethod.Parameters.Count); |
37 | | - Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); |
38 | | - Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); |
| 27 | + // Create two operations and service methods |
| 28 | + for (int i = 1; i <= 2; i++) |
| 29 | + { |
| 30 | + var operation = InputFactory.Operation( |
| 31 | + $"operation{i}", |
| 32 | + parameters: parameters, |
| 33 | + responses: [InputFactory.OperationResponse(bodytype: responseModel)]); |
| 34 | + var serviceMethod = InputFactory.LongRunningServiceMethod( |
| 35 | + $"operation{i}", |
| 36 | + operation, |
| 37 | + parameters: methodParameters, |
| 38 | + response: InputFactory.ServiceMethodResponse(responseModel, ["result"])); |
| 39 | + serviceMethods.Add(serviceMethod); |
| 40 | + } |
39 | 41 |
|
40 | | - // Act - this would normally be called by the visitor framework, but we'll test the core logic |
41 | | - visitor.InvokeRemoveSpecialHeaders(serviceMethod); |
| 42 | + var inputClient = InputFactory.Client("TestClient", methods: serviceMethods); |
42 | 43 |
|
43 | | - // Verify special headers are removed |
44 | | - Assert.AreEqual(1, serviceMethod.Parameters.Count); |
45 | | - Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); |
46 | | - Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); |
47 | | - Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter")); |
| 44 | + // Verify initial parameters include special headers for both methods |
| 45 | + foreach (var serviceMethod in serviceMethods) |
| 46 | + { |
| 47 | + Assert.AreEqual(3, serviceMethod.Parameters.Count); |
| 48 | + Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); |
| 49 | + Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); |
| 50 | + } |
| 51 | + |
| 52 | + var generator = MockHelpers.LoadMockGenerator( |
| 53 | + clients: () => [inputClient], |
| 54 | + visitors: () => [visitor]); |
| 55 | + var client = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().First(); |
| 56 | + |
| 57 | + // Verify special headers are removed from both methods |
| 58 | + foreach (var serviceMethod in serviceMethods) |
| 59 | + { |
| 60 | + var methodCollection = client.GetMethodCollectionByOperation(serviceMethod.Operation); |
| 61 | + visitor.InvokePreVisit(serviceMethod, client, methodCollection); |
| 62 | + } |
| 63 | + |
| 64 | + foreach (var serviceMethod in serviceMethods) |
| 65 | + { |
| 66 | + var restClientMethod = client.RestClient.Methods.First(m => m.Signature.Name == $"Create{serviceMethod.Name}Request"); |
| 67 | + |
| 68 | + visitor.InvokeVisit((restClientMethod as ScmMethodProvider)!); |
| 69 | + Assert.AreEqual(1, serviceMethod.Parameters.Count); |
| 70 | + Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "return-client-request-id")); |
| 71 | + Assert.IsFalse(serviceMethod.Parameters.Any(p => p.SerializedName == "x-ms-client-request-id")); |
| 72 | + Assert.IsTrue(serviceMethod.Parameters.Any(p => p.SerializedName == "some-other-parameter")); |
| 73 | + |
| 74 | + // Verify x-ms-client-request-id is added back in method body if specified |
| 75 | + Assert.AreEqual(addBackXMsClientRequestId, restClientMethod.BodyStatements!.ToDisplayString().Contains("request.Headers.SetValue(\"x-ms-client-request-id\", request.ClientRequestId);")); |
| 76 | + } |
48 | 77 | } |
49 | 78 |
|
50 | 79 | [Test] |
@@ -76,13 +105,17 @@ public void DoesNotChangeParametersWhenNoSpecialHeaders() |
76 | 105 | operation, |
77 | 106 | parameters: methodParameters, |
78 | 107 | response: InputFactory.ServiceMethodResponse(responseModel, ["result"])); |
| 108 | + var inputClient = InputFactory.Client("TestClient", methods: [serviceMethod]); |
| 109 | + var generator = MockHelpers.LoadMockGenerator(clients: () => [inputClient]); |
| 110 | + var client = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().First(); |
| 111 | + var methodCollection = client.GetMethodCollectionByOperation(operation); |
79 | 112 |
|
80 | 113 | // Verify initial state |
81 | 114 | Assert.AreEqual(1, serviceMethod.Parameters.Count); |
82 | 115 | var originalParameter = serviceMethod.Parameters[0]; |
83 | 116 |
|
84 | 117 | // Act |
85 | | - visitor.InvokeRemoveSpecialHeaders(serviceMethod); |
| 118 | + visitor.InvokePreVisit(serviceMethod, client, methodCollection); |
86 | 119 |
|
87 | 120 | // Verify no changes |
88 | 121 | Assert.AreEqual(1, serviceMethod.Parameters.Count); |
@@ -136,19 +169,18 @@ private static List<InputParameter> CreateHttpParameters() |
136 | 169 |
|
137 | 170 | private class TestSpecialHeadersVisitor : SpecialHeadersVisitor |
138 | 171 | { |
139 | | - public void InvokeRemoveSpecialHeaders(InputServiceMethod serviceMethod) |
| 172 | + public TestSpecialHeadersVisitor(bool addBackXMsClientRequestId = false) |
| 173 | + : base(addBackXMsClientRequestId) |
| 174 | + { |
| 175 | + } |
| 176 | + public void InvokePreVisit(InputServiceMethod serviceMethod, ClientProvider client, ScmMethodProviderCollection methods) |
| 177 | + { |
| 178 | + base.Visit(serviceMethod, client, methods); |
| 179 | + } |
| 180 | + |
| 181 | + public void InvokeVisit(ScmMethodProvider method) |
140 | 182 | { |
141 | | - // Simulate the core logic of removing special headers |
142 | | - var returnClientRequestIdParameter = |
143 | | - serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "return-client-request-id"); |
144 | | - var xMsClientRequestIdParameter = |
145 | | - serviceMethod.Parameters.FirstOrDefault(p => p.SerializedName == "x-ms-client-request-id"); |
146 | | - |
147 | | - if (returnClientRequestIdParameter != null || xMsClientRequestIdParameter != null) |
148 | | - { |
149 | | - serviceMethod.Update(parameters: [.. serviceMethod.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]); |
150 | | - serviceMethod.Operation.Update(parameters: [.. serviceMethod.Operation.Parameters.Where(p => p.SerializedName != "return-client-request-id" && p.SerializedName != "x-ms-client-request-id")]); |
151 | | - } |
| 183 | + base.VisitMethod(method); |
152 | 184 | } |
153 | 185 | } |
154 | 186 | } |
|
0 commit comments