Skip to content

Commit dd0261d

Browse files
refactor
Signed-off-by: Aparajita Pandey <aparajita31pandey@gmail.com>
1 parent ab49942 commit dd0261d

File tree

5 files changed

+98
-24
lines changed

5 files changed

+98
-24
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2929
- Refactor the ThreadPoolStats.Stats class to use the Builder pattern instead of constructors ([#19317](https://github.com/opensearch-project/OpenSearch/pull/19317))
3030
- Refactor the IndexingStats.Stats class to use the Builder pattern instead of constructors ([#19306](https://github.com/opensearch-project/OpenSearch/pull/19306))
3131
- Remove FeatureFlag.MERGED_SEGMENT_WARMER_EXPERIMENTAL_FLAG. ([#19715](https://github.com/opensearch-project/OpenSearch/pull/19715))
32-
-
32+
- Thread Context Preservation by gRPC Interceptor. ([#19776](https://github.com/opensearch-project/OpenSearch/pull/19776))
33+
3334
### Fixed
3435
- Fix Allocation and Rebalance Constraints of WeightFunction are incorrectly reset ([#19012](https://github.com/opensearch-project/OpenSearch/pull/19012))
3536
- Fix flaky test FieldDataLoadingIT.testIndicesFieldDataCacheSizeSetting ([#19571](https://github.com/opensearch-project/OpenSearch/pull/19571))

modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/GrpcPlugin.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import java.util.List;
5353
import java.util.Map;
5454
import java.util.function.Supplier;
55+
import java.util.stream.Collectors;
5556

5657
import io.grpc.BindableService;
5758

@@ -122,7 +123,6 @@ public void loadExtensions(ExtensiblePlugin.ExtensionLoader loader) {
122123
// Note: ThreadContext will be provided during component creation
123124
// For now, we collect providers to be initialized later with ThreadContext
124125
this.interceptorProviders = providers;
125-
126126
logger.info("Found {} gRPC interceptor providers, will initialize during component creation", providers.size());
127127
}
128128
// Load discovered gRPC service factories
@@ -358,9 +358,15 @@ public Collection<Object> createComponents(
358358
// Check for duplicates and throw exception if found
359359
for (Map.Entry<Integer, List<OrderedGrpcInterceptor>> entry : orderMap.entrySet()) {
360360
if (entry.getValue().size() > 1) {
361+
String conflictingInterceptors = entry.getValue()
362+
.stream()
363+
.map(i -> i.getInterceptor().getClass().getName())
364+
.collect(Collectors.joining(", "));
361365
throw new IllegalArgumentException(
362-
"Multiple gRPC interceptors have the same order value: "
366+
"Multiple gRPC interceptors have the same order value ["
363367
+ entry.getKey()
368+
+ "]: "
369+
+ conflictingInterceptors
364370
+ ". Each interceptor must have a unique order value."
365371
);
366372
}

modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/GrpcPluginTests.java

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,14 @@ public void setup() {
101101
// Create a real ClusterSettings instance with the plugin's settings
102102
plugin = new GrpcPlugin();
103103

104+
// Mock ThreadPool and ThreadContext
105+
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
106+
104107
// Set the client in the plugin
105108
plugin.createComponents(
106109
client,
107110
null, // ClusterService
108-
null, // ThreadPool
111+
threadPool, // ThreadPool (now properly mocked)
109112
null, // ResourceWatcherService
110113
null, // ScriptService
111114
null, // NamedXContentRegistry
@@ -255,7 +258,9 @@ public void testGetSecureAuxTransportsWithNullClient() {
255258

256259
public void testGetAuxTransportsWithServiceFactories() {
257260
GrpcPlugin newPlugin = new GrpcPlugin();
258-
newPlugin.createComponents(Mockito.mock(Client.class), null, null, null, null, null, null, null, null, null, null);
261+
ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class);
262+
when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
263+
newPlugin.createComponents(Mockito.mock(Client.class), null, mockThreadPool, null, null, null, null, null, null, null, null);
259264
ExtensiblePlugin.ExtensionLoader mockLoader = Mockito.mock(ExtensiblePlugin.ExtensionLoader.class);
260265
when(mockLoader.loadExtensions(GrpcServiceFactory.class)).thenReturn(List.of(new LoadableMockServiceFactory()));
261266
plugin.loadExtensions(mockLoader);
@@ -274,7 +279,9 @@ public void testGetAuxTransportsWithServiceFactories() {
274279

275280
public void testGetSecureAuxTransportsWithServiceFactories() {
276281
GrpcPlugin newPlugin = new GrpcPlugin();
277-
newPlugin.createComponents(Mockito.mock(Client.class), null, null, null, null, null, null, null, null, null, null);
282+
ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class);
283+
when(mockThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
284+
newPlugin.createComponents(Mockito.mock(Client.class), null, mockThreadPool, null, null, null, null, null, null, null, null);
278285
ExtensiblePlugin.ExtensionLoader mockLoader = Mockito.mock(ExtensiblePlugin.ExtensionLoader.class);
279286
when(mockLoader.loadExtensions(GrpcServiceFactory.class)).thenReturn(List.of(new LoadableMockServiceFactory()));
280287
plugin.loadExtensions(mockLoader);
@@ -430,11 +437,45 @@ public void testLoadExtensionsWithGrpcInterceptorsOrdering() {
430437
}
431438

432439
public void testLoadExtensionsWithDuplicateGrpcInterceptorOrder() {
433-
testInterceptorLoading(List.of(1, 1), IllegalArgumentException.class);
440+
GrpcPlugin plugin = new GrpcPlugin();
441+
ExtensiblePlugin.ExtensionLoader mockLoader = createMockLoader(List.of(1, 1));
442+
443+
assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader));
444+
445+
ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class);
446+
when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY));
447+
448+
IllegalArgumentException exception = expectThrows(
449+
IllegalArgumentException.class,
450+
() -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null)
451+
);
452+
453+
String errorMessage = exception.getMessage();
454+
assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [1]"));
455+
assertTrue(errorMessage.contains("ServerInterceptor")); // Mock class name will contain this
456+
assertTrue(errorMessage.contains("Each interceptor must have a unique order value"));
434457
}
435458

436459
public void testLoadExtensionsWithMultipleProvidersAndDuplicateOrder() {
437-
testInterceptorLoadingWithMultipleProviders(List.of(List.of(5), List.of(5)), IllegalArgumentException.class);
460+
GrpcPlugin plugin = new GrpcPlugin();
461+
ExtensiblePlugin.ExtensionLoader mockLoader = createMockLoaderWithMultipleProviders(List.of(List.of(5), List.of(5)));
462+
463+
// loadExtensions should succeed
464+
assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader));
465+
466+
// createComponents should fail with duplicate order
467+
ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class);
468+
when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY));
469+
470+
IllegalArgumentException exception = expectThrows(
471+
IllegalArgumentException.class,
472+
() -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null)
473+
);
474+
475+
String errorMessage = exception.getMessage();
476+
assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [5]"));
477+
assertTrue(errorMessage.contains("ServerInterceptor"));
478+
assertTrue(errorMessage.contains("Each interceptor must have a unique order value"));
438479
}
439480

440481
public void testLoadExtensionsWithNullGrpcInterceptorProviders() {
@@ -446,7 +487,23 @@ public void testLoadExtensionsWithEmptyGrpcInterceptorList() {
446487
}
447488

448489
public void testLoadExtensionsWithSameExplicitOrderInterceptors() {
449-
testInterceptorLoading(List.of(5, 5), IllegalArgumentException.class);
490+
GrpcPlugin plugin = new GrpcPlugin();
491+
ExtensiblePlugin.ExtensionLoader mockLoader = createMockLoader(List.of(5, 5));
492+
493+
assertDoesNotThrow(() -> plugin.loadExtensions(mockLoader));
494+
495+
ThreadPool mockThreadPool = Mockito.mock(ThreadPool.class);
496+
when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY));
497+
498+
IllegalArgumentException exception = expectThrows(
499+
IllegalArgumentException.class,
500+
() -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null)
501+
);
502+
503+
String errorMessage = exception.getMessage();
504+
assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [5]"));
505+
assertTrue(errorMessage.contains("ServerInterceptor"));
506+
assertTrue(errorMessage.contains("Each interceptor must have a unique order value"));
450507
}
451508

452509
// Test cases for interceptor chain failure handling
@@ -867,10 +924,16 @@ public void testGrpcInterceptorChainWithDuplicateOrders() {
867924
when(mockThreadPool.getThreadContext()).thenReturn(new org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY));
868925

869926
// Should throw exception due to duplicate orders during createComponents
870-
expectThrows(
927+
IllegalArgumentException exception = expectThrows(
871928
IllegalArgumentException.class,
872929
() -> plugin.createComponents(client, null, mockThreadPool, null, null, null, null, null, null, null, null)
873930
);
931+
932+
// Verify error message includes order value and interceptor class names
933+
String errorMessage = exception.getMessage();
934+
assertTrue(errorMessage.contains("Multiple gRPC interceptors have the same order value [10]"));
935+
assertTrue(errorMessage.contains("GrpcPluginTests"));
936+
assertTrue(errorMessage.contains("Each interceptor must have a unique order value"));
874937
}
875938

876939
/**

modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/interceptor/GrpcInterceptorChainTests.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
package org.opensearch.transport.grpc.interceptor;
1010

11+
import org.opensearch.common.settings.Settings;
12+
import org.opensearch.common.util.concurrent.ThreadContext;
1113
import org.opensearch.test.OpenSearchTestCase;
1214
import org.opensearch.transport.grpc.spi.GrpcInterceptorProvider.OrderedGrpcInterceptor;
1315
import org.junit.Before;
@@ -46,28 +48,30 @@ public class GrpcInterceptorChainTests extends OpenSearchTestCase {
4648
private ServerCall.Listener<String> mockListener;
4749

4850
private Metadata headers;
51+
private ThreadContext threadContext;
4952

5053
@Before
5154
public void setUp() throws Exception {
5255
super.setUp();
5356
MockitoAnnotations.openMocks(this);
5457
when(mockHandler.startCall(any(), any())).thenReturn(mockListener);
5558
headers = new Metadata();
59+
threadContext = new ThreadContext(Settings.EMPTY);
5660
}
5761

5862
public void testEmptyChain() {
59-
GrpcInterceptorChain chain = new GrpcInterceptorChain(Collections.emptyList());
63+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, Collections.emptyList());
6064
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);
6165

6266
assertNotNull(result);
63-
assertEquals(mockListener, result);
67+
// The result is now wrapped in a ThreadContextPreservingListener, not the raw mockListener
6468
verify(mockHandler).startCall(mockCall, headers);
6569
}
6670

6771
public void testSingleSuccessfulInterceptor() {
6872
List<OrderedGrpcInterceptor> interceptors = Arrays.asList(createTestInterceptor(10, false, null));
6973

70-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
74+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
7175
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);
7276

7377
assertNotNull(result);
@@ -81,7 +85,7 @@ public void testMultipleSuccessfulInterceptors() {
8185
createTestInterceptor(30, false, null)
8286
);
8387

84-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
88+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
8589
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);
8690

8791
assertNotNull(result);
@@ -96,7 +100,7 @@ public void testFirstInterceptorFails() {
96100
createTestInterceptor(30, false, null)
97101
);
98102

99-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
103+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
100104
chain.interceptCall(mockCall, headers, mockHandler);
101105

102106
verify(mockCall).close(
@@ -112,7 +116,7 @@ public void testMiddleInterceptorFails() {
112116
createTestInterceptor(30, false, null)
113117
);
114118

115-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
119+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
116120
chain.interceptCall(mockCall, headers, mockHandler);
117121

118122
verify(mockCall).close(
@@ -128,7 +132,7 @@ public void testLastInterceptorFails() {
128132
createTestInterceptor(30, true, "Last failure")
129133
);
130134

131-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
135+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
132136
chain.interceptCall(mockCall, headers, mockHandler);
133137

134138
verify(mockCall).close(
@@ -144,7 +148,7 @@ public void testInterceptorThrowsStatusRuntimeExceptionPermissionDenied() {
144148
createTestInterceptor(30, false, null)
145149
);
146150

147-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
151+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
148152
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);
149153

150154
assertNotNull(result);
@@ -160,7 +164,7 @@ public void testInterceptorThrowsStatusRuntimeExceptionUnauthenticated() {
160164
createTestInterceptor(20, false, null)
161165
);
162166

163-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
167+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
164168
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);
165169

166170
assertNotNull(result);
@@ -177,7 +181,7 @@ public void testInterceptorThrowsStatusRuntimeExceptionResourceExhausted() {
177181
createStatusRuntimeExceptionInterceptor(30, Status.RESOURCE_EXHAUSTED.withDescription("Rate limit exceeded"))
178182
);
179183

180-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
184+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
181185
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);
182186

183187
assertNotNull(result);
@@ -199,7 +203,7 @@ public void testInterceptorOrdering() {
199203
// Sort as GrpcPlugin would
200204
interceptors.sort((a, b) -> Integer.compare(a.order(), b.order()));
201205

202-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
206+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
203207
chain.interceptCall(mockCall, headers, mockHandler);
204208

205209
// Verify execution order
@@ -221,7 +225,7 @@ public void testChainIntegrationWithRealScenario() {
221225
createLoggingInterceptor(30, "METRICS", executionLog)
222226
);
223227

224-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
228+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
225229
chain.interceptCall(mockCall, headers, mockHandler);
226230

227231
assertEquals(Arrays.asList("AUTH", "LOGGING", "METRICS"), executionLog);
@@ -231,7 +235,7 @@ public void testChainIntegrationWithRealScenario() {
231235
* Generic test method that can be extended for different scenarios
232236
*/
233237
public void testChainWithPattern(List<OrderedGrpcInterceptor> interceptors, boolean expectSuccess, String expectedErrorMessage) {
234-
GrpcInterceptorChain chain = new GrpcInterceptorChain(interceptors);
238+
GrpcInterceptorChain chain = new GrpcInterceptorChain(threadContext, interceptors);
235239

236240
if (expectSuccess) {
237241
ServerCall.Listener<String> result = chain.interceptCall(mockCall, headers, mockHandler);

modules/transport-grpc/src/test/java/org/opensearch/transport/grpc/ssl/SecureNetty4GrpcServerTransportTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void setup() {
5252
Settings settings = Settings.builder().put("node.name", "test-node").put("grpc.netty.executor_count", 4).build();
5353
ExecutorBuilder<?> grpcExecutorBuilder = new FixedExecutorBuilder(settings, "grpc", 4, 1000, "thread_pool.grpc");
5454
threadPool = new ThreadPool(settings, grpcExecutorBuilder);
55-
serverInterceptor = new GrpcInterceptorChain(Collections.emptyList());
55+
serverInterceptor = new GrpcInterceptorChain(threadPool.getThreadContext(), Collections.emptyList());
5656
}
5757

5858
@After

0 commit comments

Comments
 (0)