Skip to content

Commit 105436c

Browse files
authored
fix: guard inflight_requests and request_duration from early returns. (#2576)
1 parent 174389e commit 105436c

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

lib/runtime/src/pipeline/network/ingress/push_handler.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::protocols::maybe_error::MaybeError;
1818
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
1919
use serde::{Deserialize, Serialize};
2020
use std::sync::Arc;
21+
use std::time::Instant;
2122
use tracing::info_span;
2223
use tracing::Instrument;
2324

@@ -106,6 +107,20 @@ impl WorkHandlerMetrics {
106107
}
107108
}
108109

110+
// RAII guard to ensure inflight gauge is decremented and request duration is observed on all code paths.
111+
struct RequestMetricsGuard {
112+
inflight_requests: prometheus::IntGauge,
113+
request_duration: prometheus::Histogram,
114+
start_time: Instant,
115+
}
116+
impl Drop for RequestMetricsGuard {
117+
fn drop(&mut self) {
118+
self.inflight_requests.dec();
119+
self.request_duration
120+
.observe(self.start_time.elapsed().as_secs_f64());
121+
}
122+
}
123+
109124
#[async_trait]
110125
impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
111126
where
@@ -125,11 +140,17 @@ where
125140
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
126141
let start_time = std::time::Instant::now();
127142

128-
if let Some(m) = self.metrics() {
143+
// Increment inflight and ensure it's decremented on all exits via RAII guard
144+
let _inflight_guard = self.metrics().map(|m| {
129145
m.request_counter.inc();
130146
m.inflight_requests.inc();
131147
m.request_bytes.inc_by(payload.len() as u64);
132-
}
148+
RequestMetricsGuard {
149+
inflight_requests: m.inflight_requests.clone(),
150+
request_duration: m.request_duration.clone(),
151+
start_time,
152+
}
153+
});
133154

134155
// decode the control message and the request
135156
let msg = TwoPartCodec::default()
@@ -292,11 +313,8 @@ where
292313
}
293314
}
294315

295-
if let Some(m) = self.metrics() {
296-
let duration = start_time.elapsed();
297-
m.request_duration.observe(duration.as_secs_f64());
298-
m.inflight_requests.dec();
299-
}
316+
// Ensure the metrics guard is not dropped until the end of the function.
317+
drop(_inflight_guard);
300318

301319
Ok(())
302320
}

0 commit comments

Comments
 (0)