Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create concept of persistent ThreadContext headers that are unstashable #8291

Merged
merged 14 commits into from
Jul 6, 2023
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- [Extensions] Support extension additional settings with extension REST initialization ([#8414](https://github.com/opensearch-project/OpenSearch/pull/8414))
- Adds mock implementation for TelemetryPlugin ([#7545](https://github.com/opensearch-project/OpenSearch/issues/7545))
- Support transport action names when registering NamedRoutes ([#7957](https://github.com/opensearch-project/OpenSearch/pull/7957))
- Create concept of persistent ThreadContext headers that are unstashable ([#8291]()https://github.com/opensearch-project/OpenSearch/pull/8291)

### Dependencies
- Bump `com.azure:azure-storage-common` from 12.21.0 to 12.21.1 (#7566, #7814)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public StoredContext stashContext() {
* Otherwise when context is stash, it should be empty.
*/

ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putPersistent(context.persistentHeaders);

if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) {
threadContextStruct = threadContextStruct.putHeaders(
Expand Down Expand Up @@ -262,6 +262,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
originalContext.requestHeaders,
originalContext.responseHeaders,
newTransientHeaders,
originalContext.persistentHeaders,
originalContext.isSystemContext,
originalContext.warningHeadersSize
);
Expand Down Expand Up @@ -337,7 +338,7 @@ public void setHeaders(Tuple<Map<String, String>, Map<String, Set<String>>> head
if (requestHeaders.isEmpty() && responseHeaders.isEmpty()) {
struct = ThreadContextStruct.EMPTY;
} else {
struct = new ThreadContextStruct(requestHeaders, responseHeaders, Collections.emptyMap(), false);
struct = new ThreadContextStruct(requestHeaders, responseHeaders, Collections.emptyMap(), Collections.emptyMap(), false);
}
threadLocal.set(struct);
}
Expand Down Expand Up @@ -375,6 +376,13 @@ public String getHeader(String key) {
return value;
}

/**
* Returns the persistent header for the given key or <code>null</code> if not present - persistent headers cannot be stashed
*/
public Object getPersistent(String key) {
dblock marked this conversation as resolved.
Show resolved Hide resolved
return threadLocal.get().persistentHeaders.get(key);
}

/**
* Returns all of the request headers from the thread's context.<br>
* <b>Be advised, headers might contain credentials.</b>
Expand Down Expand Up @@ -434,6 +442,20 @@ public void putHeader(Map<String, String> header) {
threadLocal.set(threadLocal.get().putHeaders(header));
}

/**
* Puts a persistent header into the context - persistent headers cannot be stashed
*/
public void putPersistent(String key, Object value) {
threadLocal.set(threadLocal.get().putPersistent(key, value));
}

/**
* Puts all of the given headers into this persistent context - persistent headers cannot be stashed
*/
public void putPersistent(Map<String, Object> persistentHeaders) {
threadLocal.set(threadLocal.get().putPersistent(persistentHeaders));
}

/**
* Puts a transient header object into this context
*/
Expand Down Expand Up @@ -566,12 +588,14 @@ private static final class ThreadContextStruct {
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
false
);

private final Map<String, String> requestHeaders;
private final Map<String, Object> transientHeaders;
private final Map<String, Set<String>> responseHeaders;
private final Map<String, Object> persistentHeaders;
dblock marked this conversation as resolved.
Show resolved Hide resolved
private final boolean isSystemContext;
// saving current warning headers' size not to recalculate the size with every new warning header
private final long warningHeadersSize;
Expand All @@ -580,18 +604,20 @@ private ThreadContextStruct setSystemContext() {
if (isSystemContext) {
return this;
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, true);
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true);
}

private ThreadContextStruct(
Map<String, String> requestHeaders,
Map<String, Set<String>> responseHeaders,
Map<String, Object> transientHeaders,
Map<String, Object> persistentHeaders,
boolean isSystemContext
) {
this.requestHeaders = requestHeaders;
this.responseHeaders = responseHeaders;
this.transientHeaders = transientHeaders;
this.persistentHeaders = persistentHeaders;
this.isSystemContext = isSystemContext;
this.warningHeadersSize = 0L;
}
Expand All @@ -600,12 +626,14 @@ private ThreadContextStruct(
Map<String, String> requestHeaders,
Map<String, Set<String>> responseHeaders,
Map<String, Object> transientHeaders,
Map<String, Object> persistentHeaders,
boolean isSystemContext,
long warningHeadersSize
) {
this.requestHeaders = requestHeaders;
this.responseHeaders = responseHeaders;
this.transientHeaders = transientHeaders;
this.persistentHeaders = persistentHeaders;
this.isSystemContext = isSystemContext;
this.warningHeadersSize = warningHeadersSize;
}
Expand All @@ -614,13 +642,13 @@ private ThreadContextStruct(
* This represents the default context and it should only ever be called by {@link #DEFAULT_CONTEXT}.
*/
private ThreadContextStruct() {
this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false);
this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false);
}

private ThreadContextStruct putRequest(String key, String value) {
Map<String, String> newRequestHeaders = new HashMap<>(this.requestHeaders);
putSingleHeader(key, value, newRequestHeaders);
return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext);
return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, persistentHeaders, isSystemContext);
}

private static <T> void putSingleHeader(String key, T value, Map<String, T> newHeaders) {
Expand All @@ -637,7 +665,25 @@ private ThreadContextStruct putHeaders(Map<String, String> headers) {
for (Map.Entry<String, String> entry : headers.entrySet()) {
putSingleHeader(entry.getKey(), entry.getValue(), newHeaders);
}
return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders, isSystemContext);
return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders, persistentHeaders, isSystemContext);
}
}

private ThreadContextStruct putPersistent(String key, Object value) {
Map<String, Object> newPersistentHeaders = new HashMap<>(this.persistentHeaders);
putSingleHeader(key, value, newPersistentHeaders);
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, newPersistentHeaders, isSystemContext);
}

private ThreadContextStruct putPersistent(Map<String, Object> headers) {
if (headers.isEmpty()) {
return this;
} else {
final Map<String, Object> newPersistentHeaders = new HashMap<>(this.persistentHeaders);
for (Map.Entry<String, Object> entry : headers.entrySet()) {
putSingleHeader(entry.getKey(), entry.getValue(), newPersistentHeaders);
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, newPersistentHeaders, isSystemContext);
}
}

Expand All @@ -658,7 +704,7 @@ private ThreadContextStruct putResponseHeaders(Map<String, Set<String>> headers)
newResponseHeaders.put(key, entry.getValue());
}
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext);
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, persistentHeaders, isSystemContext);
}

private ThreadContextStruct putResponse(
Expand Down Expand Up @@ -695,6 +741,7 @@ private ThreadContextStruct putResponse(
requestHeaders,
responseHeaders,
transientHeaders,
persistentHeaders,
isSystemContext,
newWarningHeaderSize
);
Expand Down Expand Up @@ -730,21 +777,28 @@ private ThreadContextStruct putResponse(
return this;
}
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext, newWarningHeaderSize);
return new ThreadContextStruct(
requestHeaders,
newResponseHeaders,
transientHeaders,
persistentHeaders,
isSystemContext,
newWarningHeaderSize
);
}

private ThreadContextStruct putTransient(Map<String, Object> values) {
Map<String, Object> newTransient = new HashMap<>(this.transientHeaders);
for (Map.Entry<String, Object> entry : values.entrySet()) {
putSingleHeader(entry.getKey(), entry.getValue(), newTransient);
}
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext);
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, persistentHeaders, isSystemContext);
}

private ThreadContextStruct putTransient(String key, Object value) {
Map<String, Object> newTransient = new HashMap<>(this.transientHeaders);
putSingleHeader(key, value, newTransient);
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext);
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, persistentHeaders, isSystemContext);
}

private ThreadContextStruct copyHeaders(Iterable<Map.Entry<String, String>> headers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@ public void testStashContext() {
assertEquals("1", threadContext.getHeader("default"));
}

public void testStashContextWithPersistentHeaders() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.putHeader("foo", "bar");
threadContext.putTransient("ctx.foo", 1);
threadContext.putPersistent("persistent_foo", "baz");
threadContext.putPersistent("ctx.persistent_foo", 10);
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));
dblock marked this conversation as resolved.
Show resolved Hide resolved

assertEquals("baz", threadContext.getPersistent("persistent_foo"));
assertEquals(Integer.valueOf(10), threadContext.getPersistent("ctx.persistent_foo"));
assertNull(threadContext.getPersistent("default"));
}

assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
assertEquals("1", threadContext.getHeader("default"));

assertEquals("baz", threadContext.getPersistent("persistent_foo"));
assertEquals(Integer.valueOf(10), threadContext.getPersistent("ctx.persistent_foo"));
assertNull(threadContext.getPersistent("default"));
}

public void testNewContextWithClearedTransients() {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
threadContext.putTransient("foo", "bar");
Expand Down