Skip to content

Commit ee86bad

Browse files
authored
feat: Validation engine for validating OpenAI api request data (#1674)
1 parent f0652d8 commit ee86bad

File tree

7 files changed

+663
-63
lines changed

7 files changed

+663
-63
lines changed

lib/llm/src/engines.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,19 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
124124
/// Useful for testing ingress such as service-http.
125125
struct EchoEngineFull {}
126126

127+
/// Validation engine that verifies request data before forwarding it to the inner engine
128+
pub struct ValidateEngine<E> {
129+
inner: E,
130+
}
131+
132+
impl<E> ValidateEngine<E> {
133+
pub fn new(inner: E) -> Self {
134+
Self { inner }
135+
}
136+
}
137+
127138
/// Engine that dispatches requests to either OpenAICompletions
128-
//or OpenAIChatCompletions engine
139+
/// or OpenAIChatCompletions engine
129140
pub struct EngineDispatcher<E> {
130141
inner: E,
131142
}
@@ -136,6 +147,11 @@ impl<E> EngineDispatcher<E> {
136147
}
137148
}
138149

150+
/// Trait on request types that allows us to validate the data
151+
pub trait ValidateRequest {
152+
fn validate(&self) -> Result<(), anyhow::Error>;
153+
}
154+
139155
/// Trait that allows handling both completion and chat completions requests
140156
#[async_trait]
141157
pub trait StreamingEngine: Send + Sync {
@@ -267,6 +283,30 @@ impl
267283
}
268284
}
269285

286+
#[async_trait]
287+
impl<E, Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for ValidateEngine<E>
288+
where
289+
E: AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> + Send + Sync,
290+
Req: ValidateRequest + Send + Sync + 'static,
291+
Resp: Send + Sync + 'static,
292+
{
293+
async fn generate(
294+
&self,
295+
incoming_request: SingleIn<Req>,
296+
) -> Result<ManyOut<Annotated<Resp>>, Error> {
297+
let (request, context) = incoming_request.into_parts();
298+
299+
// Validate the request first
300+
if let Err(validation_error) = request.validate() {
301+
return Err(anyhow::anyhow!("Validation failed: {}", validation_error));
302+
}
303+
304+
// Forward to inner engine if validation passes
305+
let validated_request = SingleIn::rejoin(request, context);
306+
self.inner.generate(validated_request).await
307+
}
308+
}
309+
270310
#[async_trait]
271311
impl<E> StreamingEngine for EngineDispatcher<E>
272312
where

lib/llm/src/protocols/openai.rs

Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16-
use std::fmt::Display;
17-
1816
use anyhow::Result;
1917
use serde::{Deserialize, Serialize};
2018

@@ -29,42 +27,11 @@ pub mod embeddings;
2927
pub mod models;
3028
pub mod nvext;
3129
pub mod responses;
30+
pub mod validate;
3231

33-
/// Minimum allowed value for OpenAI's `temperature` sampling option
34-
pub const MIN_TEMPERATURE: f32 = 0.0;
35-
36-
/// Maximum allowed value for OpenAI's `temperature` sampling option
37-
pub const MAX_TEMPERATURE: f32 = 2.0;
38-
39-
/// Allowed range of values for OpenAI's `temperature`` sampling option
40-
pub const TEMPERATURE_RANGE: (f32, f32) = (MIN_TEMPERATURE, MAX_TEMPERATURE);
41-
42-
/// Minimum allowed value for OpenAI's `top_p` sampling option
43-
pub const MIN_TOP_P: f32 = 0.0;
44-
45-
/// Maximum allowed value for OpenAI's `top_p` sampling option
46-
pub const MAX_TOP_P: f32 = 1.0;
47-
48-
/// Allowed range of values for OpenAI's `top_p` sampling option
49-
pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P);
50-
51-
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
52-
pub const MIN_FREQUENCY_PENALTY: f32 = -2.0;
53-
54-
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
55-
pub const MAX_FREQUENCY_PENALTY: f32 = 2.0;
56-
57-
/// Allowed range of values for OpenAI's `frequency_penalty` sampling option
58-
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (MIN_FREQUENCY_PENALTY, MAX_FREQUENCY_PENALTY);
59-
60-
/// Minimum allowed value for OpenAI's `presence_penalty` sampling option
61-
pub const MIN_PRESENCE_PENALTY: f32 = -2.0;
62-
63-
/// Maximum allowed value for OpenAI's `presence_penalty` sampling option
64-
pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
65-
66-
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
67-
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
32+
use validate::{
33+
validate_range, FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE,
34+
};
6835

6936
#[derive(Serialize, Deserialize, Debug)]
7037
pub struct AnnotatedDelta<R> {
@@ -166,21 +133,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
166133
}
167134
}
168135

169-
// todo - move to common location
170-
fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
171-
where
172-
T: PartialOrd + Display,
173-
{
174-
if value.is_none() {
175-
return Ok(None);
176-
}
177-
let value = value.unwrap();
178-
if value < range.0 || value > range.1 {
179-
anyhow::bail!("Value {} is out of range [{}, {}]", value, range.0, range.1);
180-
}
181-
Ok(Some(value))
182-
}
183-
184136
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
185137
Send + Sync + 'static
186138
{

lib/llm/src/protocols/openai/chat_completions.rs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ use dynamo_runtime::protocols::annotated::AnnotationsProvider;
1717
use serde::{Deserialize, Serialize};
1818
use validator::Validate;
1919

20-
use super::nvext::NvExt;
21-
use super::nvext::NvExtProvider;
22-
use super::OpenAISamplingOptionsProvider;
23-
use super::OpenAIStopConditionsProvider;
20+
use crate::engines::ValidateRequest;
21+
22+
use super::{
23+
nvext::NvExt, nvext::NvExtProvider, validate, OpenAISamplingOptionsProvider,
24+
OpenAIStopConditionsProvider,
25+
};
2426

2527
mod aggregator;
2628
mod delta;
@@ -174,3 +176,42 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
174176
self.nvext.as_ref()
175177
}
176178
}
179+
180+
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
181+
/// allowing us to validate the data.
182+
impl ValidateRequest for NvCreateChatCompletionRequest {
183+
fn validate(&self) -> Result<(), anyhow::Error> {
184+
validate::validate_messages(&self.inner.messages)?;
185+
validate::validate_model(&self.inner.model)?;
186+
// none for store
187+
validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
188+
validate::validate_metadata(&self.inner.metadata)?;
189+
validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
190+
validate::validate_logit_bias(&self.inner.logit_bias)?;
191+
// none for logprobs
192+
validate::validate_top_logprobs(self.inner.top_logprobs)?;
193+
// validate::validate_max_tokens(self.inner.max_tokens)?; // warning deprecated field
194+
validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?;
195+
validate::validate_n(self.inner.n)?;
196+
// none for modalities
197+
// none for prediction
198+
// none for audio
199+
validate::validate_presence_penalty(self.inner.presence_penalty)?;
200+
// none for response_format
201+
// none for seed
202+
validate::validate_service_tier(&self.inner.service_tier)?;
203+
validate::validate_stop(&self.inner.stop)?;
204+
// none for stream
205+
// none for stream_options
206+
validate::validate_temperature(self.inner.temperature)?;
207+
validate::validate_top_p(self.inner.top_p)?;
208+
validate::validate_tools(&self.inner.tools.as_deref())?;
209+
// none for tool_choice
210+
// none for parallel_tool_calls
211+
validate::validate_user(self.inner.user.as_deref())?;
212+
// none for function_call
213+
// none for functions
214+
215+
Ok(())
216+
}
217+
}

lib/llm/src/protocols/openai/completions.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ use dynamo_runtime::protocols::annotated::AnnotationsProvider;
1818
use serde::{Deserialize, Serialize};
1919
use validator::Validate;
2020

21+
use crate::engines::ValidateRequest;
22+
2123
use super::{
2224
common::{self, SamplingOptionsProvider, StopConditionsProvider},
2325
nvext::{NvExt, NvExtProvider},
24-
ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
26+
validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
2527
};
2628

2729
mod aggregator;
@@ -275,3 +277,30 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
275277
Ok(choice)
276278
}
277279
}
280+
281+
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
282+
/// allowing us to validate the data.
283+
impl ValidateRequest for NvCreateCompletionRequest {
284+
fn validate(&self) -> Result<(), anyhow::Error> {
285+
validate::validate_model(&self.inner.model)?;
286+
validate::validate_prompt(&self.inner.prompt)?;
287+
validate::validate_suffix(self.inner.suffix.as_deref())?;
288+
validate::validate_max_tokens(self.inner.max_tokens)?;
289+
validate::validate_temperature(self.inner.temperature)?;
290+
validate::validate_top_p(self.inner.top_p)?;
291+
validate::validate_n(self.inner.n)?;
292+
// none for stream
293+
// none for stream_options
294+
validate::validate_logprobs(self.inner.logprobs)?;
295+
// none for echo
296+
validate::validate_stop(&self.inner.stop)?;
297+
validate::validate_presence_penalty(self.inner.presence_penalty)?;
298+
validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
299+
validate::validate_best_of(self.inner.best_of, self.inner.n)?;
300+
validate::validate_logit_bias(&self.inner.logit_bias)?;
301+
validate::validate_user(self.inner.user.as_deref())?;
302+
// none for seed
303+
304+
Ok(())
305+
}
306+
}

0 commit comments

Comments
 (0)