diff --git a/CHANGELOG.md b/CHANGELOG.md index 619bb0d58d8ae..e0b9cc8a4d642 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Extensions] Support extension additional settings with extension REST initialization ([#8414](https://github.com/opensearch-project/OpenSearch/pull/8414)) - Adds mock implementation for TelemetryPlugin ([#7545](https://github.com/opensearch-project/OpenSearch/issues/7545)) - Support transport action names when registering NamedRoutes ([#7957](https://github.com/opensearch-project/OpenSearch/pull/7957)) +- Create concept of persistent ThreadContext headers that are unstashable ([#8291]()https://github.com/opensearch-project/OpenSearch/pull/8291) ### Dependencies - Bump `com.azure:azure-storage-common` from 12.21.0 to 12.21.1 (#7566, #7814) diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 025fb7a36b684..9dd5d21a00231 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -149,7 +149,7 @@ public StoredContext stashContext() { * Otherwise when context is stash, it should be empty. */ - ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT; + ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putPersistent(context.persistentHeaders); if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) { threadContextStruct = threadContextStruct.putHeaders( @@ -262,6 +262,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio originalContext.requestHeaders, originalContext.responseHeaders, newTransientHeaders, + originalContext.persistentHeaders, originalContext.isSystemContext, originalContext.warningHeadersSize ); @@ -337,7 +338,7 @@ public void setHeaders(Tuple, Map>> head if (requestHeaders.isEmpty() && responseHeaders.isEmpty()) { struct = ThreadContextStruct.EMPTY; } else { - struct = new ThreadContextStruct(requestHeaders, responseHeaders, Collections.emptyMap(), false); + struct = new ThreadContextStruct(requestHeaders, responseHeaders, Collections.emptyMap(), Collections.emptyMap(), false); } threadLocal.set(struct); } @@ -375,6 +376,13 @@ public String getHeader(String key) { return value; } + /** + * Returns the persistent header for the given key or null if not present - persistent headers cannot be stashed + */ + public Object getPersistent(String key) { + return threadLocal.get().persistentHeaders.get(key); + } + /** * Returns all of the request headers from the thread's context.
* Be advised, headers might contain credentials. @@ -434,6 +442,20 @@ public void putHeader(Map header) { threadLocal.set(threadLocal.get().putHeaders(header)); } + /** + * Puts a persistent header into the context - persistent headers cannot be stashed + */ + public void putPersistent(String key, Object value) { + threadLocal.set(threadLocal.get().putPersistent(key, value)); + } + + /** + * Puts all of the given headers into this persistent context - persistent headers cannot be stashed + */ + public void putPersistent(Map persistentHeaders) { + threadLocal.set(threadLocal.get().putPersistent(persistentHeaders)); + } + /** * Puts a transient header object into this context */ @@ -566,12 +588,14 @@ private static final class ThreadContextStruct { Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), + Collections.emptyMap(), false ); private final Map requestHeaders; private final Map transientHeaders; private final Map> responseHeaders; + private final Map persistentHeaders; private final boolean isSystemContext; // saving current warning headers' size not to recalculate the size with every new warning header private final long warningHeadersSize; @@ -580,18 +604,20 @@ private ThreadContextStruct setSystemContext() { if (isSystemContext) { return this; } - return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, true); + return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true); } private ThreadContextStruct( Map requestHeaders, Map> responseHeaders, Map transientHeaders, + Map persistentHeaders, boolean isSystemContext ) { this.requestHeaders = requestHeaders; this.responseHeaders = responseHeaders; this.transientHeaders = transientHeaders; + this.persistentHeaders = persistentHeaders; this.isSystemContext = isSystemContext; this.warningHeadersSize = 0L; } @@ -600,12 +626,14 @@ private ThreadContextStruct( Map requestHeaders, Map> responseHeaders, Map transientHeaders, + Map persistentHeaders, boolean isSystemContext, long warningHeadersSize ) { this.requestHeaders = requestHeaders; this.responseHeaders = responseHeaders; this.transientHeaders = transientHeaders; + this.persistentHeaders = persistentHeaders; this.isSystemContext = isSystemContext; this.warningHeadersSize = warningHeadersSize; } @@ -614,13 +642,13 @@ private ThreadContextStruct( * This represents the default context and it should only ever be called by {@link #DEFAULT_CONTEXT}. */ private ThreadContextStruct() { - this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false); + this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false); } private ThreadContextStruct putRequest(String key, String value) { Map newRequestHeaders = new HashMap<>(this.requestHeaders); putSingleHeader(key, value, newRequestHeaders); - return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext); + return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, persistentHeaders, isSystemContext); } private static void putSingleHeader(String key, T value, Map newHeaders) { @@ -637,7 +665,25 @@ private ThreadContextStruct putHeaders(Map headers) { for (Map.Entry entry : headers.entrySet()) { putSingleHeader(entry.getKey(), entry.getValue(), newHeaders); } - return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders, isSystemContext); + return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders, persistentHeaders, isSystemContext); + } + } + + private ThreadContextStruct putPersistent(String key, Object value) { + Map newPersistentHeaders = new HashMap<>(this.persistentHeaders); + putSingleHeader(key, value, newPersistentHeaders); + return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, newPersistentHeaders, isSystemContext); + } + + private ThreadContextStruct putPersistent(Map headers) { + if (headers.isEmpty()) { + return this; + } else { + final Map newPersistentHeaders = new HashMap<>(this.persistentHeaders); + for (Map.Entry entry : headers.entrySet()) { + putSingleHeader(entry.getKey(), entry.getValue(), newPersistentHeaders); + } + return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, newPersistentHeaders, isSystemContext); } } @@ -658,7 +704,7 @@ private ThreadContextStruct putResponseHeaders(Map> headers) newResponseHeaders.put(key, entry.getValue()); } } - return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext); + return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, persistentHeaders, isSystemContext); } private ThreadContextStruct putResponse( @@ -695,6 +741,7 @@ private ThreadContextStruct putResponse( requestHeaders, responseHeaders, transientHeaders, + persistentHeaders, isSystemContext, newWarningHeaderSize ); @@ -730,7 +777,14 @@ private ThreadContextStruct putResponse( return this; } } - return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext, newWarningHeaderSize); + return new ThreadContextStruct( + requestHeaders, + newResponseHeaders, + transientHeaders, + persistentHeaders, + isSystemContext, + newWarningHeaderSize + ); } private ThreadContextStruct putTransient(Map values) { @@ -738,13 +792,13 @@ private ThreadContextStruct putTransient(Map values) { for (Map.Entry entry : values.entrySet()) { putSingleHeader(entry.getKey(), entry.getValue(), newTransient); } - return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext); + return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, persistentHeaders, isSystemContext); } private ThreadContextStruct putTransient(String key, Object value) { Map newTransient = new HashMap<>(this.transientHeaders); putSingleHeader(key, value, newTransient); - return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext); + return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, persistentHeaders, isSystemContext); } private ThreadContextStruct copyHeaders(Iterable> headers) { diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index 64286e47b4966..dfa239757513e 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -71,6 +71,35 @@ public void testStashContext() { assertEquals("1", threadContext.getHeader("default")); } + public void testStashContextWithPersistentHeaders() { + Settings build = Settings.builder().put("request.headers.default", "1").build(); + ThreadContext threadContext = new ThreadContext(build); + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("ctx.foo", 1); + threadContext.putPersistent("persistent_foo", "baz"); + threadContext.putPersistent("ctx.persistent_foo", 10); + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + assertNull(threadContext.getHeader("foo")); + assertNull(threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + + assertEquals("baz", threadContext.getPersistent("persistent_foo")); + assertEquals(Integer.valueOf(10), threadContext.getPersistent("ctx.persistent_foo")); + assertNull(threadContext.getPersistent("default")); + } + + assertEquals("bar", threadContext.getHeader("foo")); + assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); + assertEquals("1", threadContext.getHeader("default")); + + assertEquals("baz", threadContext.getPersistent("persistent_foo")); + assertEquals(Integer.valueOf(10), threadContext.getPersistent("ctx.persistent_foo")); + assertNull(threadContext.getPersistent("default")); + } + public void testNewContextWithClearedTransients() { ThreadContext threadContext = new ThreadContext(Settings.EMPTY); threadContext.putTransient("foo", "bar");