diff --git a/dd-java-agent/appsec/src/jmh/java/datadog/appsec/benchmark/AppSecBenchmark.java b/dd-java-agent/appsec/src/jmh/java/datadog/appsec/benchmark/AppSecBenchmark.java
index bf2fe4bff78..e1c8c5a9356 100644
--- a/dd-java-agent/appsec/src/jmh/java/datadog/appsec/benchmark/AppSecBenchmark.java
+++ b/dd-java-agent/appsec/src/jmh/java/datadog/appsec/benchmark/AppSecBenchmark.java
@@ -247,6 +247,14 @@ public BlockResponseFunction getBlockResponseFunction() {
return null;
}
+ @Override
+ public void setRequiresPostProcessing(boolean postProcessing) {}
+
+ @Override
+ public boolean isRequiresPostProcessing() {
+ return false;
+ }
+
@Override
public void close() throws IOException {}
}
diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java
index db43c401f0d..5771efe0587 100644
--- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java
+++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java
@@ -1,6 +1,5 @@
package com.datadog.appsec;
-import com.datadog.appsec.api.security.ApiSecurityRequestSampler;
import com.datadog.appsec.blocking.BlockingServiceImpl;
import com.datadog.appsec.config.AppSecConfigService;
import com.datadog.appsec.config.AppSecConfigServiceImpl;
@@ -77,15 +76,12 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s
sco.createRemaining(config);
RateLimiter rateLimiter = getRateLimiter(config, sco.monitoring);
- ApiSecurityRequestSampler requestSampler =
- new ApiSecurityRequestSampler(config, configurationPoller);
GatewayBridge gatewayBridge =
new GatewayBridge(
gw,
REPLACEABLE_EVENT_PRODUCER,
rateLimiter,
- requestSampler,
APP_SEC_CONFIG_SERVICE.getTraceSegmentPostProcessors());
loadModules(eventDispatcher);
diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiAccessTracker.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiAccessTracker.java
new file mode 100644
index 00000000000..d63c70f918e
--- /dev/null
+++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiAccessTracker.java
@@ -0,0 +1,79 @@
+package com.datadog.appsec.api.security;
+
+import java.util.LinkedHashMap;
+
+/**
+ * The ApiAccessTracker class provides a mechanism to track API access events, managing them within
+ * a specified capacity limit. Each event is associated with a unique combination of route, method,
+ * and status code, which is used to generate a unique key for tracking access timestamps.
+ *
+ *
Usage: - When an API access event occurs, the `updateApiAccessIfExpired` method is called with
+ * the route, method, and status code of the API request. - If the access event for the given
+ * parameters is new or has expired (based on the expirationTimeInMs threshold), the event's
+ * timestamp is updated, effectively moving the event to the end of the tracking list. - If the
+ * tracker's capacity is reached, the oldest event is automatically removed to make room for new
+ * events. - This mechanism ensures that the tracker always contains the most recent access events
+ * within the specified capacity limit, with older, less relevant events being discarded.
+ */
+public class ApiAccessTracker {
+
+ private static final int INTERVAL_SECONDS = 30;
+ private static final int MAX_SIZE = 4096;
+ private final LinkedHashMap apiAccessLog; // Map
+ private final int capacity;
+ private final long expirationTimeInMs;
+
+ public ApiAccessTracker() {
+ this(MAX_SIZE, INTERVAL_SECONDS * 1000);
+ }
+
+ public ApiAccessTracker(int capacity, long expirationTimeInMs) {
+ this.capacity = capacity;
+ this.expirationTimeInMs = expirationTimeInMs;
+ this.apiAccessLog = new LinkedHashMap<>();
+ }
+
+ /**
+ * Updates the API access log with the given route, method, and status code. If the record exists
+ * and is outdated, it is updated by moving to the end of the list. If the record does not exist,
+ * a new record is added. If the capacity limit is reached, the oldest record is removed. Returns
+ * true if the record was updated or added, false otherwise.
+ *
+ * @param route
+ * @param method
+ * @param statusCode
+ * @return return true if the record was updated or added, false otherwise
+ */
+ public boolean updateApiAccessIfExpired(String route, String method, int statusCode) {
+ long currentTime = System.currentTimeMillis();
+ long hash = computeApiHash(route, method, statusCode);
+
+ // If the record exists and is outdated, update it by moving to the end of the list
+ if (apiAccessLog.containsKey(hash)) {
+ long lastAccessTime = apiAccessLog.get(hash);
+ if (currentTime - lastAccessTime > expirationTimeInMs) {
+ // Remove and add the record to update the timestamp and move it to the end of the list
+ apiAccessLog.remove(hash);
+ apiAccessLog.put(hash, currentTime);
+ return true;
+ }
+ return false;
+ } else {
+ // If the record does not exist, just add a new one
+ if (apiAccessLog.size() >= capacity) {
+ // Remove the oldest record if the capacity limit is reached
+ apiAccessLog.remove(apiAccessLog.keySet().iterator().next());
+ }
+ apiAccessLog.put(hash, currentTime);
+ return true;
+ }
+ }
+
+ private long computeApiHash(String route, String method, int statusCode) {
+ long result = 17;
+ result = 31 * result + route.hashCode();
+ result = 31 * result + method.hashCode();
+ result = 31 * result + statusCode;
+ return result;
+ }
+}
diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityRequestSampler.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityRequestSampler.java
deleted file mode 100644
index 4ebb8fde35b..00000000000
--- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityRequestSampler.java
+++ /dev/null
@@ -1,78 +0,0 @@
-package com.datadog.appsec.api.security;
-
-import static datadog.remoteconfig.tuf.RemoteConfigRequest.ClientInfo.CAPABILITY_ASM_API_SECURITY_SAMPLE_RATE;
-
-import com.datadog.appsec.config.AppSecFeaturesDeserializer;
-import datadog.remoteconfig.ConfigurationPoller;
-import datadog.remoteconfig.Product;
-import datadog.trace.api.Config;
-import java.util.concurrent.atomic.AtomicLong;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class ApiSecurityRequestSampler {
-
- private static final Logger log = LoggerFactory.getLogger(ApiSecurityRequestSampler.class);
-
- private volatile int sampling;
- private final AtomicLong cumulativeCounter = new AtomicLong();
-
- public ApiSecurityRequestSampler(final Config config) {
- sampling = computeSamplingParameter(config.getApiSecurityRequestSampleRate());
- }
-
- public ApiSecurityRequestSampler(final Config config, ConfigurationPoller configurationPoller) {
- this(config);
- if (configurationPoller == null) {
- return;
- }
-
- configurationPoller.addListener(
- Product.ASM_FEATURES,
- "asm_api_security",
- AppSecFeaturesDeserializer.INSTANCE,
- (configKey, newConfig, pollingRateHinter) -> {
- if (newConfig != null && newConfig.apiSecurity != null) {
- Float newSamplingFloat = newConfig.apiSecurity.requestSampleRate;
- if (newSamplingFloat != null) {
- int newSampling = computeSamplingParameter(newSamplingFloat);
- if (newSampling != sampling) {
- sampling = newSampling;
- cumulativeCounter.set(0); // Reset current sampling counter
- if (sampling == 0) {
- log.info("Api Security is disabled via remote-config");
- } else {
- log.info(
- "Api Security changed via remote-config. New sampling rate is {}% of all requests.",
- sampling);
- }
- }
- }
- }
- });
- configurationPoller.addCapabilities(CAPABILITY_ASM_API_SECURITY_SAMPLE_RATE);
- }
-
- public boolean sampleRequest() {
- long prevValue = cumulativeCounter.getAndAdd(sampling);
- long newValue = prevValue + sampling;
- if (newValue / 100 == prevValue / 100 + 1) {
- // Sample request
- return true;
- }
- // Skipped by sampling
- return false;
- }
-
- static int computeSamplingParameter(final float pct) {
- if (pct >= 1) {
- return 100;
- }
- if (pct < 0) {
- // We don't support disabling Api Security by setting it, so we set it to 100%.
- // TODO: We probably want a warning here.
- return 100;
- }
- return (int) (pct * 100);
- }
-}
diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/config/AppSecFeatures.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/config/AppSecFeatures.java
index cba6d679e98..fb62ab44e4e 100644
--- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/config/AppSecFeatures.java
+++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/config/AppSecFeatures.java
@@ -3,15 +3,7 @@
public class AppSecFeatures {
public Asm asm;
- @com.squareup.moshi.Json(name = "api_security")
- public ApiSecurity apiSecurity;
-
public static class Asm {
public boolean enabled;
}
-
- public static class ApiSecurity {
- @com.squareup.moshi.Json(name = "request_sample_rate")
- public Float requestSampleRate;
- }
}
diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java
index 29ec03ff35b..cf390416700 100644
--- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java
+++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java
@@ -3,7 +3,7 @@
import static com.datadog.appsec.event.data.MapDataBundle.Builder.CAPACITY_6_10;
import com.datadog.appsec.AppSecSystem;
-import com.datadog.appsec.api.security.ApiSecurityRequestSampler;
+import com.datadog.appsec.api.security.ApiAccessTracker;
import com.datadog.appsec.config.TraceSegmentPostProcessor;
import com.datadog.appsec.event.EventProducerService;
import com.datadog.appsec.event.EventProducerService.DataSubscriberInfo;
@@ -29,6 +29,7 @@
import datadog.trace.api.http.StoredBodySupplier;
import datadog.trace.api.internal.TraceSegment;
import datadog.trace.api.telemetry.WafMetricCollector;
+import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter;
import datadog.trace.util.Strings;
@@ -60,10 +61,10 @@ public class GatewayBridge {
private static final Pattern QUERY_PARAM_SPLITTER = Pattern.compile("&");
private static final Map> EMPTY_QUERY_PARAMS = Collections.emptyMap();
+ private final ApiAccessTracker apiAccessTracker = new ApiAccessTracker();
private final SubscriptionService subscriptionService;
private final EventProducerService producerService;
private final RateLimiter rateLimiter;
- private final ApiSecurityRequestSampler requestSampler;
private final List traceSegmentPostProcessors;
// subscriber cache
@@ -80,12 +81,10 @@ public GatewayBridge(
SubscriptionService subscriptionService,
EventProducerService producerService,
RateLimiter rateLimiter,
- ApiSecurityRequestSampler requestSampler,
List traceSegmentPostProcessors) {
this.subscriptionService = subscriptionService;
this.producerService = producerService;
this.rateLimiter = rateLimiter;
- this.requestSampler = requestSampler;
this.traceSegmentPostProcessors = traceSegmentPostProcessors;
}
@@ -112,11 +111,6 @@ public void init() {
return NoopFlow.INSTANCE;
}
- maybeExtractSchemas(ctx);
-
- // WAF call
- ctx.closeAdditive();
-
TraceSegment traceSeg = ctx_.getTraceSegment();
// AppSec report metric and events for web span only
@@ -175,11 +169,6 @@ public void init() {
}
}
- // If extracted any Api Schemas - commit them
- if (!ctx.commitApiSchemas(traceSeg)) {
- log.debug("Unable to commit, api security schemas and will be skipped");
- }
-
if (ctx.isBlocked()) {
WafMetricCollector.get().wafRequestBlocked();
} else if (!collectedEvents.isEmpty()) {
@@ -189,7 +178,7 @@ public void init() {
}
}
- ctx.close();
+ // ctx.close();
return NoopFlow.INSTANCE;
});
@@ -234,6 +223,7 @@ public void init() {
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.REQUEST_PATH_PARAMS, data);
try {
+ ctx_.setRequiresPostProcessing(true);
return producerService.publishDataEvent(subInfo, ctx, bundle, false);
} catch (ExpiredSubscriberInfoException e) {
pathParamsSubInfo = null;
@@ -305,6 +295,7 @@ public void init() {
new SingletonDataBundle<>(
KnownAddresses.REQUEST_BODY_OBJECT, ObjectIntrospection.convert(obj));
try {
+ ctx_.setRequiresPostProcessing(true);
return producerService.publishDataEvent(subInfo, ctx, bundle, false);
} catch (ExpiredSubscriberInfoException e) {
requestBodySubInfo = null;
@@ -322,7 +313,7 @@ public void init() {
}
ctx.setPeerAddress(ip);
ctx.setPeerPort(port);
- return maybePublishRequestData(ctx);
+ return maybePublishRequestData(ctx_, ctx);
});
subscriptionService.registerCallback(
@@ -419,6 +410,27 @@ public void init() {
}
}
});
+
+ subscriptionService.registerCallback(
+ EVENTS.postProcessing(),
+ (ctx_, span) -> {
+ AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
+ if (ctx == null) {
+ return;
+ }
+
+ maybeExtractSchemas(ctx, span);
+ ctx.closeAdditive();
+
+ TraceSegment traceSeg = ctx_.getTraceSegment();
+
+ if (traceSeg != null) {
+ // If extracted any Api Schemas - commit them
+ if (!ctx.commitApiSchemas(traceSeg)) {
+ log.debug("Unable to commit, api security schemas and will be skipped");
+ }
+ }
+ });
}
public void stop() {
@@ -474,7 +486,7 @@ public Flow apply(RequestContext ctx_) {
return NoopFlow.INSTANCE;
}
ctx.finishRequestHeaders();
- return maybePublishRequestData(ctx);
+ return maybePublishRequestData(ctx_, ctx);
}
}
@@ -510,11 +522,11 @@ public Flow apply(RequestContext ctx_, String method, URIDataAdapter uri)
log.debug("Failed to encode URI '{}{}'", uri.path(), uri.query());
}
}
- return maybePublishRequestData(ctx);
+ return maybePublishRequestData(ctx_, ctx);
}
}
- private Flow maybePublishRequestData(AppSecRequestContext ctx) {
+ private Flow maybePublishRequestData(RequestContext reqCtx, AppSecRequestContext ctx) {
String savedRawURI = ctx.getSavedRawURI();
if (savedRawURI == null || !ctx.isFinishedRequestHeaders() || ctx.getPeerAddress() == null) {
@@ -568,6 +580,7 @@ private Flow maybePublishRequestData(AppSecRequestContext ctx) {
}
try {
+ reqCtx.setRequiresPostProcessing(true);
return producerService.publishDataEvent(subInfo, ctx, bundle, false);
} catch (ExpiredSubscriberInfoException e) {
this.initialReqDataSubInfo = null;
@@ -607,13 +620,27 @@ private Flow maybePublishResponseData(AppSecRequestContext ctx) {
}
}
- private void maybeExtractSchemas(AppSecRequestContext ctx) {
- boolean extractSchema = false;
- if (Config.get().isApiSecurityEnabled() && requestSampler != null) {
- extractSchema = requestSampler.sampleRequest();
+ private void maybeExtractSchemas(AppSecRequestContext ctx, AgentSpan span) {
+ boolean extractSchema = Config.get().isApiSecurityEnabled();
+ if (!extractSchema) {
+ return;
}
- if (!extractSchema) {
+ Object routeObj = span.getTag(Tags.HTTP_ROUTE);
+ String route = routeObj instanceof String ? (String) routeObj : null;
+
+ Object methodObj = span.getTag(Tags.HTTP_METHOD);
+ String method = methodObj instanceof String ? (String) methodObj : null;
+
+ Object statusCodeObj = span.getTag(Tags.HTTP_STATUS);
+ int statusCode = statusCodeObj instanceof Integer ? (Integer) statusCodeObj : 0;
+
+ if (route == null || method == null || statusCode == 0) {
+ return;
+ }
+
+ boolean sampled = apiAccessTracker.updateApiAccessIfExpired(route, method, statusCode);
+ if (!sampled) {
return;
}
diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiAccessTrackerTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiAccessTrackerTest.groovy
new file mode 100644
index 00000000000..7cf49042110
--- /dev/null
+++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiAccessTrackerTest.groovy
@@ -0,0 +1,55 @@
+package com.datadog.appsec.api.security
+
+import datadog.trace.test.util.DDSpecification
+
+class ApiAccessTrackerTest extends DDSpecification {
+ def "should add new api access and update if expired"() {
+ given: "An ApiAccessTracker with capacity 2 and expiration time 1 second"
+ def tracker = new ApiAccessTracker(2, 1000)
+
+ when: "Adding new api access"
+ tracker.updateApiAccessIfExpired("route1", "GET", 200)
+ def firstAccessTime = tracker.apiAccessLog.values().iterator().next()
+
+ then: "The access is added"
+ tracker.apiAccessLog.size() == 1
+
+ when: "Waiting more than expiration time and adding another access with the same key"
+ Thread.sleep(1100) // Waiting more than 1 second to ensure expiration
+ tracker.updateApiAccessIfExpired("route1", "GET", 200)
+ def secondAccessTime = tracker.apiAccessLog.values().iterator().next()
+
+ then: "The access is updated and moved to the end"
+ tracker.apiAccessLog.size() == 1
+ secondAccessTime > firstAccessTime
+ }
+
+ def "should remove the oldest access when capacity is exceeded"() {
+ given: "An ApiAccessTracker with capacity 1"
+ def tracker = new ApiAccessTracker(1, 1000)
+
+ when: "Adding two api accesses"
+ tracker.updateApiAccessIfExpired("route1", "GET", 200)
+ Thread.sleep(100) // Delay to ensure different timestamps
+ tracker.updateApiAccessIfExpired("route2", "POST", 404)
+
+ then: "The oldest access is removed"
+ tracker.apiAccessLog.size() == 1
+ !tracker.apiAccessLog.containsKey(tracker.computeApiHash("route1", "GET", 200))
+ tracker.apiAccessLog.containsKey(tracker.computeApiHash("route2", "POST", 404))
+ }
+
+ def "should not update access if not expired"() {
+ given: "An ApiAccessTracker with a short expiration time"
+ def tracker = new ApiAccessTracker(2, 2000) // 2 seconds expiration
+
+ when: "Adding an api access and updating it before it expires"
+ tracker.updateApiAccessIfExpired("route1", "GET", 200)
+ def updateTime = System.currentTimeMillis()
+ boolean updatedBeforeExpiration = tracker.updateApiAccessIfExpired("route1", "GET", 200)
+
+ then: "The access is not updated"
+ !updatedBeforeExpiration
+ tracker.apiAccessLog.get(tracker.computeApiHash("route1", "GET", 200)) == updateTime
+ }
+}
\ No newline at end of file
diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityRequestSamplerTest.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityRequestSamplerTest.groovy
deleted file mode 100644
index ec59242f4e1..00000000000
--- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityRequestSamplerTest.groovy
+++ /dev/null
@@ -1,78 +0,0 @@
-package com.datadog.appsec.api.security
-
-import com.datadog.appsec.config.AppSecFeatures
-import com.datadog.appsec.config.AppSecFeaturesDeserializer
-import datadog.remoteconfig.ConfigurationChangesTypedListener
-import datadog.remoteconfig.ConfigurationPoller
-import datadog.remoteconfig.Product
-import datadog.trace.api.Config
-import datadog.trace.test.util.DDSpecification
-import spock.lang.Shared
-
-class ApiSecurityRequestSamplerTest extends DDSpecification {
-
- @Shared
- static final float DEFAULT_SAMPLE_RATE = Config.get().getApiSecurityRequestSampleRate()
-
- void 'Api Security Request Sample Rate'() {
- given:
- def config = Spy(Config.get())
- config.getApiSecurityRequestSampleRate() >> sampleRate
- def sampler = new ApiSecurityRequestSampler(config)
-
- when:
- def numOfRequest = expectedSampledRequests.size()
- def results = new int[numOfRequest]
- for (int i = 0; i < numOfRequest; i++) {
- results[i] = sampler.sampleRequest() ? 1 : 0
- }
-
- then:
- results == expectedSampledRequests as int[]
-
- where:
- sampleRate | expectedSampledRequests
- DEFAULT_SAMPLE_RATE | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] // Default sample rate - 10%
- 0.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- 0.1 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
- 0.25 | [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]
- 0.33 | [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1]
- 0.5 | [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
- 0.75 | [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
- 0.9 | [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
- 0.99 | [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
- 1.0 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
- 1.25 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] // Wrong sample rate - use 100%
- -0.5 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] // Wrong sample rate - use 100%
- }
-
- void 'update sample rate via remote-config'() {
- given:
- ConfigurationPoller poller = Mock()
- def config = Spy(Config.get())
- ConfigurationChangesTypedListener listener
- AppSecFeatures newConfig = new AppSecFeatures().tap {
- asm = new AppSecFeatures.Asm().tap {
- enabled = true
- }
- apiSecurity = new AppSecFeatures.ApiSecurity().tap {
- requestSampleRate = 0.2
- }
- }
-
- when:
- def sampler = new ApiSecurityRequestSampler(config, poller)
-
- then:
- 1 * poller.addListener(Product.ASM_FEATURES, 'asm_api_security', AppSecFeaturesDeserializer.INSTANCE, _) >> {
- listener = it[3] as ConfigurationChangesTypedListener
- }
- listener != null
-
- when:
- listener.accept(null, newConfig, null)
-
- then:
- sampler.sampling == 20
- }
-}
diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeIGRegistrationSpecification.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeIGRegistrationSpecification.groovy
index 7457fd23e98..eee76de12e1 100644
--- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeIGRegistrationSpecification.groovy
+++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeIGRegistrationSpecification.groovy
@@ -10,7 +10,7 @@ class GatewayBridgeIGRegistrationSpecification extends DDSpecification {
SubscriptionService ig = Mock()
EventDispatcher eventDispatcher = Mock()
- GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, null, [])
+ GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, [])
void 'request_body_start and request_body_done are registered'() {
given:
diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy
index d5f53e1979a..b345173e3bc 100644
--- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy
+++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/GatewayBridgeSpecification.groovy
@@ -24,6 +24,7 @@ import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapterBase
import datadog.trace.test.util.DDSpecification
+import java.util.function.BiConsumer
import java.util.function.BiFunction
import java.util.function.Function
import java.util.function.Supplier
@@ -49,6 +50,15 @@ class GatewayBridgeSpecification extends DDSpecification {
GatewayBridgeSpecification.this.traceSegment
}
+ @Override
+ boolean isRequiresPostProcessing() {
+ return false
+ }
+
+ @Override
+ void setRequiresPostProcessing(boolean postProcessing) {
+ }
+
@Override
void close() throws IOException {}
}
@@ -60,7 +70,7 @@ class GatewayBridgeSpecification extends DDSpecification {
RateLimiter rateLimiter = new RateLimiter(10, { -> 0L } as TimeSource, RateLimiter.ThrottledCallback.NOOP)
TraceSegmentPostProcessor pp = Mock()
- GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, rateLimiter, null, [pp])
+ GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, rateLimiter, [pp])
Supplier> requestStartedCB
BiFunction> requestEndedCB
@@ -78,6 +88,7 @@ class GatewayBridgeSpecification extends DDSpecification {
Function> respHeadersDoneCB
BiFunction> grpcServerRequestMessageCB
BiFunction, Flow> graphqlServerRequestMessageCB
+ BiConsumer postProcessingCB
void setup() {
callInitAndCaptureCBs()
@@ -134,7 +145,6 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * spanInfo.getTags() >> ['http.client_ip':'1.1.1.1']
1 * mockAppSecCtx.transferCollectedEvents() >> [event]
1 * mockAppSecCtx.peerAddress >> '2001::1'
- 1 * mockAppSecCtx.close()
1 * traceSegment.setTagTop('manual.keep', true)
1 * traceSegment.setTagTop("_dd.appsec.enabled", 1)
1 * traceSegment.setTagTop("_dd.runtime_family", "jvm")
@@ -143,7 +153,6 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * traceSegment.setTagTop('http.request.headers.accept', 'header_value')
1 * traceSegment.setTagTop('http.response.headers.content-type', 'text/html; charset=UTF-8')
1 * traceSegment.setTagTop('network.client.ip', '2001::1')
- 1 * mockAppSecCtx.closeAdditive()
flow.result == null
flow.action == Flow.Action.Noop.INSTANCE
}
@@ -163,8 +172,6 @@ class GatewayBridgeSpecification extends DDSpecification {
then:
11 * mockAppSecCtx.transferCollectedEvents() >> [event]
- 11 * mockAppSecCtx.close()
- 11 * mockAppSecCtx.closeAdditive()
10 * spanInfo.getTags() >> ['http.client_ip':'1.1.1.1']
10 * traceSegment.setDataTop("appsec", _)
}
@@ -412,6 +419,7 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * ig.registerCallback(EVENTS.responseHeaderDone(), _) >> { respHeadersDoneCB = it[1]; null }
1 * ig.registerCallback(EVENTS.grpcServerRequestMessage(), _) >> { grpcServerRequestMessageCB = it[1]; null }
1 * ig.registerCallback(EVENTS.graphqlServerRequestMessage(), _) >> { graphqlServerRequestMessageCB = it[1]; null }
+ 1 * ig.registerCallback(EVENTS.postProcessing(), _) >> { postProcessingCB = it[1]; null }
0 * ig.registerCallback(_, _)
bridge.init()
@@ -731,6 +739,14 @@ class GatewayBridgeSpecification extends DDSpecification {
GatewayBridgeSpecification.this.traceSegment
}
+ @Override
+ void setRequiresPostProcessing(boolean postProcessing) {}
+
+ @Override
+ boolean isRequiresPostProcessing() {
+ return false
+ }
+
@Override
void close() throws IOException {}
}
diff --git a/dd-trace-core/src/main/java/datadog/trace/common/writer/DDAgentWriter.java b/dd-trace-core/src/main/java/datadog/trace/common/writer/DDAgentWriter.java
index 0c85cb49abc..e36da816175 100644
--- a/dd-trace-core/src/main/java/datadog/trace/common/writer/DDAgentWriter.java
+++ b/dd-trace-core/src/main/java/datadog/trace/common/writer/DDAgentWriter.java
@@ -14,6 +14,8 @@
import datadog.trace.common.writer.ddagent.DDAgentMapperDiscovery;
import datadog.trace.common.writer.ddagent.Prioritization;
import datadog.trace.core.monitor.HealthMetrics;
+import datadog.trace.core.postprocessor.AppSecSpanPostProcessor;
+import datadog.trace.core.postprocessor.SpanPostProcessor;
import java.util.concurrent.TimeUnit;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
@@ -145,6 +147,7 @@ public DDAgentWriter build() {
final DDAgentMapperDiscovery mapperDiscovery = new DDAgentMapperDiscovery(featureDiscovery);
final PayloadDispatcher dispatcher =
new PayloadDispatcherImpl(mapperDiscovery, agentApi, healthMetrics, monitoring);
+ final SpanPostProcessor spanPostProcessor = new AppSecSpanPostProcessor();
final TraceProcessingWorker traceProcessingWorker =
new TraceProcessingWorker(
traceBufferSize,
@@ -155,7 +158,7 @@ public DDAgentWriter build() {
flushIntervalMilliseconds,
TimeUnit.MILLISECONDS,
singleSpanSampler,
- null);
+ spanPostProcessor);
return new DDAgentWriter(traceProcessingWorker, dispatcher, healthMetrics, alwaysFlush);
}
diff --git a/dd-trace-core/src/main/java/datadog/trace/common/writer/DDIntakeWriter.java b/dd-trace-core/src/main/java/datadog/trace/common/writer/DDIntakeWriter.java
index 7075266352c..6e918be2435 100644
--- a/dd-trace-core/src/main/java/datadog/trace/common/writer/DDIntakeWriter.java
+++ b/dd-trace-core/src/main/java/datadog/trace/common/writer/DDIntakeWriter.java
@@ -9,6 +9,8 @@
import datadog.trace.common.writer.ddagent.Prioritization;
import datadog.trace.common.writer.ddintake.DDIntakeMapperDiscovery;
import datadog.trace.core.monitor.HealthMetrics;
+import datadog.trace.core.postprocessor.AppSecSpanPostProcessor;
+import datadog.trace.core.postprocessor.SpanPostProcessor;
import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@@ -113,6 +115,8 @@ public DDIntakeWriter build() {
dispatcher = new CompositePayloadDispatcher(dispatchers);
}
+ SpanPostProcessor spanPostProcessor = new AppSecSpanPostProcessor();
+
final TraceProcessingWorker traceProcessingWorker =
new TraceProcessingWorker(
traceBufferSize,
@@ -123,7 +127,7 @@ public DDIntakeWriter build() {
flushIntervalMilliseconds,
TimeUnit.MILLISECONDS,
singleSpanSampler,
- null);
+ spanPostProcessor);
return new DDIntakeWriter(
traceProcessingWorker,
diff --git a/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java b/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java
index c17c00932bb..2d40384fa03 100644
--- a/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java
+++ b/dd-trace-core/src/main/java/datadog/trace/common/writer/TraceProcessingWorker.java
@@ -7,6 +7,7 @@
import datadog.communication.ddagent.DroppingPolicy;
import datadog.trace.api.Config;
+import datadog.trace.api.gateway.RequestContext;
import datadog.trace.common.sampling.SingleSpanSampler;
import datadog.trace.common.writer.ddagent.FlushEvent;
import datadog.trace.common.writer.ddagent.Prioritization;
@@ -16,6 +17,7 @@
import datadog.trace.core.DDSpanContext;
import datadog.trace.core.monitor.HealthMetrics;
import datadog.trace.core.postprocessor.SpanPostProcessor;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
@@ -270,6 +272,7 @@ public void onEvent(Object event) {
if (event instanceof List) {
List trace = (List) event;
maybeTracePostProcessing(trace);
+ closeRequestContext(trace);
// TODO populate `_sample_rate` metric in a way that accounts for lost/dropped traces
payloadDispatcher.addTrace(trace);
} else if (event instanceof FlushEvent) {
@@ -369,5 +372,23 @@ private void maybeTracePostProcessing(List trace) {
}
}
}
+
+ private void closeRequestContext(List trace) {
+ if (trace == null || trace.isEmpty()) {
+ return;
+ }
+
+ DDSpan rootSpan = trace.get(0);
+ RequestContext requestContext = rootSpan.getRequestContext();
+ if (requestContext == null) {
+ return;
+ }
+
+ try {
+ requestContext.close();
+ } catch (IOException e) {
+ log.warn("Error closing request context data", e);
+ }
+ }
}
}
diff --git a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java
index d2472114d52..10893dae706 100644
--- a/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java
+++ b/dd-trace-core/src/main/java/datadog/trace/core/CoreTracer.java
@@ -956,17 +956,6 @@ void write(final List trace) {
}
if (null != rootSpan) {
onRootSpanFinished(rootSpan, rootSpan.getEndpointTracker());
-
- // request context is propagated to contexts in child spans
- // Assume here that if present it will be so starting in the top span
- RequestContext requestContext = rootSpan.getRequestContext();
- if (requestContext != null) {
- try {
- requestContext.close();
- } catch (IOException e) {
- log.warn("Error closing request context data", e);
- }
- }
}
}
diff --git a/dd-trace-core/src/main/java/datadog/trace/core/DDSpanContext.java b/dd-trace-core/src/main/java/datadog/trace/core/DDSpanContext.java
index eb11bb93b40..cc56ae6ebbe 100644
--- a/dd-trace-core/src/main/java/datadog/trace/core/DDSpanContext.java
+++ b/dd-trace-core/src/main/java/datadog/trace/core/DDSpanContext.java
@@ -946,10 +946,12 @@ private String getTagName(String key) {
return "_dd." + key + ".json";
}
+ @Override
public void setRequiresPostProcessing(boolean postProcessing) {
this.requiresPostProcessing = postProcessing;
}
+ @Override
public boolean isRequiresPostProcessing() {
return requiresPostProcessing;
}
diff --git a/dd-trace-core/src/main/java/datadog/trace/core/postprocessor/AppSecSpanPostProcessor.java b/dd-trace-core/src/main/java/datadog/trace/core/postprocessor/AppSecSpanPostProcessor.java
new file mode 100644
index 00000000000..3fe40a66e4a
--- /dev/null
+++ b/dd-trace-core/src/main/java/datadog/trace/core/postprocessor/AppSecSpanPostProcessor.java
@@ -0,0 +1,48 @@
+package datadog.trace.core.postprocessor;
+
+import static datadog.trace.api.gateway.Events.EVENTS;
+
+import datadog.trace.api.gateway.CallbackProvider;
+import datadog.trace.api.gateway.RequestContext;
+import datadog.trace.api.gateway.RequestContextSlot;
+import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
+import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
+import datadog.trace.core.DDSpan;
+import datadog.trace.core.DDSpanContext;
+import java.util.function.BiConsumer;
+import java.util.function.BooleanSupplier;
+
+public class AppSecSpanPostProcessor implements SpanPostProcessor {
+
+ // Extract this to allow for easier testing
+ protected AgentTracer.TracerAPI tracer() {
+ return AgentTracer.get();
+ }
+
+ @Override
+ public boolean process(DDSpan span, BooleanSupplier timeoutCheck) {
+ DDSpanContext context = span.context();
+ if (context == null) {
+ return false;
+ }
+
+ CallbackProvider cbp = tracer().getCallbackProvider(RequestContextSlot.APPSEC);
+ if (cbp == null) {
+ return false;
+ }
+
+ RequestContext ctx = span.getRequestContext();
+ if (ctx == null) {
+ return false;
+ }
+
+ BiConsumer postProcessingCallback =
+ cbp.getCallback(EVENTS.postProcessing());
+ if (postProcessingCallback == null) {
+ return false;
+ }
+
+ postProcessingCallback.accept(ctx, span);
+ return true;
+ }
+}
diff --git a/dd-trace-core/src/test/groovy/datadog/trace/core/DDSpanTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/core/DDSpanTest.groovy
index 8fbcd58005d..d79b62bad0f 100644
--- a/dd-trace-core/src/test/groovy/datadog/trace/core/DDSpanTest.groovy
+++ b/dd-trace-core/src/test/groovy/datadog/trace/core/DDSpanTest.groovy
@@ -3,7 +3,6 @@ package datadog.trace.core
import datadog.trace.api.DDSpanId
import datadog.trace.api.DDTags
import datadog.trace.api.DDTraceId
-import datadog.trace.api.gateway.RequestContextSlot
import datadog.trace.api.sampling.PrioritySampling
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.AgentTracer.NoopPathwayContext
@@ -319,30 +318,6 @@ class DDSpanTest extends DDCoreSpecification {
new ExtractedContext(DDTraceId.from(123), 456, PrioritySampling.SAMPLER_KEEP, "789", propagationTagsFactory.empty(), DATADOG) | false
}
- def 'publishing of root span closes the request context data'() {
- setup:
- def reqContextData = Mock(Closeable)
- def context = new TagContext().withRequestContextDataAppSec(reqContextData)
- def root = tracer.buildSpan("root").asChildOf(context).start()
- def child = tracer.buildSpan("child").asChildOf(root).start()
-
- expect:
- root.requestContext.getData(RequestContextSlot.APPSEC).is(reqContextData)
- child.requestContext.getData(RequestContextSlot.APPSEC).is(reqContextData)
-
- when:
- child.finish()
-
- then:
- 0 * reqContextData.close()
-
- when:
- root.finish()
-
- then:
- 1 * reqContextData.close()
- }
-
def "infer top level from parent service name"() {
setup:
def propagationTagsFactory = tracer.getPropagationTagsFactory()
diff --git a/dd-trace-core/src/test/groovy/datadog/trace/core/postprocessor/AppSecSpanPostProcessorTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/core/postprocessor/AppSecSpanPostProcessorTest.groovy
new file mode 100644
index 00000000000..eb916324d12
--- /dev/null
+++ b/dd-trace-core/src/test/groovy/datadog/trace/core/postprocessor/AppSecSpanPostProcessorTest.groovy
@@ -0,0 +1,122 @@
+package datadog.trace.core.postprocessor
+
+import datadog.trace.api.gateway.CallbackProvider
+import datadog.trace.api.gateway.RequestContext
+import datadog.trace.api.gateway.RequestContextSlot
+import datadog.trace.bootstrap.instrumentation.api.AgentTracer
+import datadog.trace.core.DDSpan
+import datadog.trace.core.DDSpanContext
+import datadog.trace.core.PendingTrace
+import datadog.trace.test.util.DDSpecification
+
+import java.util.function.BiConsumer
+import java.util.function.BooleanSupplier
+
+import static datadog.trace.api.gateway.Events.EVENTS
+
+
+class AppSecSpanPostProcessorTest extends DDSpecification {
+ def "process returns false if span context is null"() {
+ given:
+ def processor = new AppSecSpanPostProcessor()
+ def span = Mock(DDSpan)
+ def timeoutCheck = Mock(BooleanSupplier)
+ (span.context()) >> null
+
+ expect:
+ !processor.process(span, timeoutCheck)
+ }
+
+ def "process returns false if callback provider is null"() {
+ given:
+ AgentTracer.TracerAPI tracer = Mock(AgentTracer.TracerAPI)
+ tracer.getCallbackProvider(RequestContextSlot.APPSEC) >> null
+ def processor = new AppSecSpanPostProcessor() {
+ @Override
+ protected AgentTracer.TracerAPI tracer() {
+ return tracer
+ }
+ }
+ def span = Mock(DDSpan) {
+ context() >> Mock(DDSpanContext)
+ }
+ def timeoutCheck = Mock(BooleanSupplier)
+
+ expect:
+ !processor.process(span, timeoutCheck)
+ }
+
+ def "process returns false if request context is null"() {
+ given:
+ AgentTracer.TracerAPI tracer = Mock(AgentTracer.TracerAPI)
+ def cbp = Mock(CallbackProvider)
+ tracer.getCallbackProvider(RequestContextSlot.APPSEC) >> cbp
+ def processor = new AppSecSpanPostProcessor() {
+ @Override
+ protected AgentTracer.TracerAPI tracer() {
+ return tracer
+ }
+ }
+ def span = Mock(DDSpan) {
+ context() >> Mock(DDSpanContext)
+ getRequestContext() >> null
+ }
+ def timeoutCheck = Mock(BooleanSupplier)
+
+ expect:
+ !processor.process(span, timeoutCheck)
+ }
+
+ def "process returns false if post-processing callback is null"() {
+ given:
+ AgentTracer.TracerAPI tracer = Mock(AgentTracer.TracerAPI)
+ def cbp = Mock(CallbackProvider)
+ tracer.getCallbackProvider(RequestContextSlot.APPSEC) >> cbp
+ cbp.getCallback(EVENTS.postProcessing()) >> null
+ def processor = new AppSecSpanPostProcessor() {
+ @Override
+ protected AgentTracer.TracerAPI tracer() {
+ return tracer
+ }
+ }
+ def span = Mock(DDSpan) {
+ context() >> Mock(DDSpanContext)
+ getRequestContext() >> Mock(RequestContext)
+ }
+ def timeoutCheck = Mock(BooleanSupplier)
+
+ expect:
+ !processor.process(span, timeoutCheck)
+ }
+
+ def "process returns true when all components are properly configured"() {
+ given:
+ def callback = Mock(BiConsumer)
+ AgentTracer.TracerAPI tracer = Mock(AgentTracer.TracerAPI)
+ def cbp = Mock(CallbackProvider)
+ tracer.getCallbackProvider(RequestContextSlot.APPSEC) >> cbp
+ cbp.getCallback(EVENTS.postProcessing()) >> callback
+ def processor = new AppSecSpanPostProcessor() {
+ @Override
+ protected AgentTracer.TracerAPI tracer() {
+ return tracer
+ }
+ }
+ def span = DDSpan.create("test", 0, Mock(DDSpanContext) {
+ isRequiresPostProcessing() >> true
+ getTrace() >> Mock(PendingTrace) {
+ getCurrentTimeNano() >> 0
+ }
+ getRequestContext() >> Mock(RequestContext)
+ }, [])
+ def timeoutCheck = Mock(BooleanSupplier)
+
+ when:
+ boolean result = processor.process(span, timeoutCheck)
+
+ then:
+ result
+ 1 * callback.accept(_, _)
+ }
+}
+
diff --git a/internal-api/src/main/java/datadog/trace/api/gateway/Events.java b/internal-api/src/main/java/datadog/trace/api/gateway/Events.java
index e4472874b6d..d6e394bf3c5 100644
--- a/internal-api/src/main/java/datadog/trace/api/gateway/Events.java
+++ b/internal-api/src/main/java/datadog/trace/api/gateway/Events.java
@@ -3,9 +3,11 @@
import datadog.trace.api.function.TriConsumer;
import datadog.trace.api.function.TriFunction;
import datadog.trace.api.http.StoredBodySupplier;
+import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
@@ -202,6 +204,17 @@ public EventType>> grpcServerReque
GRAPHQL_SERVER_REQUEST_MESSAGE;
}
+ static final int POST_PROCESSING_ID = 16;
+
+ @SuppressWarnings("rawtypes")
+ private static final EventType POST_PROCESSING =
+ new ET<>("trace.post.processing", POST_PROCESSING_ID);
+
+ @SuppressWarnings("unchecked")
+ public EventType> postProcessing() {
+ return (EventType>) POST_PROCESSING;
+ }
+
static final int MAX_EVENTS = nextId.get();
private static final class ET extends EventType {
diff --git a/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java b/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java
index 128fdfd156f..633de5c0645 100644
--- a/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java
+++ b/internal-api/src/main/java/datadog/trace/api/gateway/InstrumentationGateway.java
@@ -3,6 +3,7 @@
import static datadog.trace.api.gateway.Events.GRAPHQL_SERVER_REQUEST_MESSAGE_ID;
import static datadog.trace.api.gateway.Events.GRPC_SERVER_REQUEST_MESSAGE_ID;
import static datadog.trace.api.gateway.Events.MAX_EVENTS;
+import static datadog.trace.api.gateway.Events.POST_PROCESSING_ID;
import static datadog.trace.api.gateway.Events.REQUEST_BODY_CONVERTED_ID;
import static datadog.trace.api.gateway.Events.REQUEST_BODY_DONE_ID;
import static datadog.trace.api.gateway.Events.REQUEST_BODY_START_ID;
@@ -21,9 +22,11 @@
import datadog.trace.api.function.TriConsumer;
import datadog.trace.api.function.TriFunction;
import datadog.trace.api.http.StoredBodySupplier;
+import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReferenceArray;
+import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
@@ -360,6 +363,18 @@ public Flow apply(RequestContext ctx, Integer status) {
}
}
};
+ case POST_PROCESSING_ID:
+ return (C)
+ new BiConsumer() {
+ @Override
+ public void accept(RequestContext ctx, AgentSpan span) {
+ try {
+ ((BiConsumer) callback).accept(ctx, span);
+ } catch (Throwable t) {
+ log.warn("Callback for {} threw.", eventType, t);
+ }
+ }
+ };
default:
log.warn("Unwrapped callback for {}", eventType);
return callback;
diff --git a/internal-api/src/main/java/datadog/trace/api/gateway/RequestContext.java b/internal-api/src/main/java/datadog/trace/api/gateway/RequestContext.java
index cef89020eab..07b2fbd1dc4 100644
--- a/internal-api/src/main/java/datadog/trace/api/gateway/RequestContext.java
+++ b/internal-api/src/main/java/datadog/trace/api/gateway/RequestContext.java
@@ -17,6 +17,10 @@ public interface RequestContext extends Closeable {
BlockResponseFunction getBlockResponseFunction();
+ void setRequiresPostProcessing(boolean postProcessing);
+
+ boolean isRequiresPostProcessing();
+
class Noop implements RequestContext {
public static final RequestContext INSTANCE = new Noop();
@@ -40,6 +44,14 @@ public BlockResponseFunction getBlockResponseFunction() {
return null;
}
+ @Override
+ public void setRequiresPostProcessing(boolean postProcessing) {}
+
+ @Override
+ public boolean isRequiresPostProcessing() {
+ return false;
+ }
+
@Override
public void close() throws IOException {}
}
diff --git a/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java b/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java
index 78b3e0d3820..7fed0439e4e 100644
--- a/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java
+++ b/internal-api/src/test/java/datadog/trace/api/gateway/InstrumentationGatewayTest.java
@@ -56,6 +56,14 @@ public void setBlockResponseFunction(BlockResponseFunction blockResponseFunction
public BlockResponseFunction getBlockResponseFunction() {
return null;
}
+
+ @Override
+ public void setRequiresPostProcessing(boolean postProcessing) {}
+
+ @Override
+ public boolean isRequiresPostProcessing() {
+ return false;
+ }
};
flow = new Flow.ResultFlow<>(null);
callback = new Callback(context, flow);
@@ -197,6 +205,8 @@ public void testNormalCalls() {
ss.registerCallback(events.graphqlServerRequestMessage(), callback);
assertThat(cbp.getCallback(events.graphqlServerRequestMessage()).apply(null, null).getAction())
.isEqualTo(Flow.Action.Noop.INSTANCE);
+ ss.registerCallback(events.postProcessing(), callback);
+ cbp.getCallback(events.postProcessing()).accept(null, null);
assertThat(callback.count).isEqualTo(Events.MAX_EVENTS);
}
@@ -246,6 +256,8 @@ public void testThrowableBlocking() {
ss.registerCallback(events.graphqlServerRequestMessage(), throwback);
assertThat(cbp.getCallback(events.graphqlServerRequestMessage()).apply(null, null).getAction())
.isEqualTo(Flow.Action.Noop.INSTANCE);
+ ss.registerCallback(events.postProcessing(), throwback);
+ cbp.getCallback(events.postProcessing()).accept(null, null);
assertThat(throwback.count).isEqualTo(Events.MAX_EVENTS);
}