Skip to content

Commit

Permalink
chore(clenaup): moved stuffs around
Browse files Browse the repository at this point in the history
  • Loading branch information
ssoudan committed May 17, 2023
1 parent fcb2e1c commit e6c2e6b
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 207 deletions.
50 changes: 50 additions & 0 deletions language/src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//! Authentication related code.

use std::str::FromStr;

use tonic::Request;

use crate::{Credentials, Error};

/// API Key authorization
#[derive(Clone)]
pub struct APIKey {
api_key: String,
}

/// Authentication interceptor
#[derive(Clone)]
pub enum Authentication {
/// API Key authentication
APIKey(APIKey),
/// No authentication
None,
}

impl Authentication {
/// Build an authentication interceptor from the given credentials
pub async fn build(credentials: Credentials) -> Result<Authentication, Error> {
match credentials {
Credentials::ApiKey(api_key) => {
let authz = APIKey { api_key };

Ok(Authentication::APIKey(authz))
}
Credentials::None => Ok(Authentication::None),
}
}
}

impl tonic::service::Interceptor for Authentication {
fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, tonic::Status> {
match self {
Authentication::APIKey(api_key_auth) => {
let api_key = api_key_auth.api_key.clone();
let api_key = FromStr::from_str(&api_key).unwrap();
req.metadata_mut().insert("x-goog-api-key", api_key);
Ok(req)
}
Authentication::None => Ok(req),
}
}
}
244 changes: 37 additions & 207 deletions language/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
//!
//! An async client library for GCP Vertex AI Generative models

use std::str::FromStr;
pub mod auth;

pub use auth::Authentication;
use google::ai::generativelanguage::v1beta2::discuss_service_client::DiscussServiceClient;
use google::ai::generativelanguage::v1beta2::model_service_client::ModelServiceClient;
use google::ai::generativelanguage::v1beta2::text_service_client::TextServiceClient;
use tonic::codegen::http::uri::InvalidUri;
use tonic::transport::{Certificate, Channel, ClientTlsConfig};
use tonic::Request;

/// Errors that can occur when using [LanguageClient].
#[derive(thiserror::Error, Debug)]
Expand All @@ -21,7 +21,7 @@ pub enum Error {
/// Invalid URI.
#[error("{0}")]
InvalidUri(#[from] InvalidUri),
/// Vizier service error.
/// Service error.
#[error("Status: {}", .0.message())]
Status(#[from] tonic::Status),
}
Expand All @@ -31,8 +31,10 @@ const CERTIFICATES: &str = include_str!("../certs/roots.pem");
/// Credentials to use to connect to the services
#[derive(Clone)]
pub enum Credentials {
/// API Key
/// API Key - see https://cloud.google.com/docs/authentication/api-keys
ApiKey(String),
/// No authentication
None,
}

/// google protos.
Expand Down Expand Up @@ -66,15 +68,23 @@ pub mod google {
/// Generative Language client.
#[derive(Clone)]
pub struct LanguageClient {
/// The Discuss service client.
pub discuss_service:
DiscussServiceClient<tonic::service::interceptor::InterceptedService<Channel, Authz>>,
/// The Model service client.
pub model_service:
ModelServiceClient<tonic::service::interceptor::InterceptedService<Channel, Authz>>,
/// The Text service client.
/// The Discuss service client. In particular, this client is used for
/// [`DiscussServiceClient::count_message_tokens`] and
/// [`DiscussServiceClient::generate_message`].
pub discuss_service: DiscussServiceClient<
tonic::service::interceptor::InterceptedService<Channel, Authentication>,
>,
/// The Model service client. Notably, this client is used for
/// [`ModelServiceClient::list_models`] and
/// [`ModelServiceClient::get_model`].
pub model_service: ModelServiceClient<
tonic::service::interceptor::InterceptedService<Channel, Authentication>,
>,
/// The Text service client. Notably, this client is used for
/// [`TextServiceClient::generate_text`],
/// and [`TextServiceClient::embed_text`].
pub text_service:
TextServiceClient<tonic::service::interceptor::InterceptedService<Channel, Authz>>,
TextServiceClient<tonic::service::interceptor::InterceptedService<Channel, Authentication>>,
}

impl LanguageClient {
Expand Down Expand Up @@ -103,25 +113,32 @@ impl LanguageClient {

let endpoint = format!("https://{endpoint}", endpoint = domain_name);

dbg!(&endpoint);
let channel = Channel::from_shared(endpoint)?
.user_agent("github.com/ssoudan/gcp-vertex-ai-generative-ai")?
.tls_config(tls_config)?
.connect_lazy();

Self::from_channel(credentials, channel).await
}

/// Creates a new LanguageClient from a Channel.
pub async fn from_channel(
credentials: Credentials,
channel: Channel,
) -> Result<LanguageClient, Error> {
let discuss_service = {
let authz = Authz::build(credentials.clone()).await?;
DiscussServiceClient::with_interceptor(channel.clone(), authz)
let auth = Authentication::build(credentials.clone()).await?;
DiscussServiceClient::with_interceptor(channel.clone(), auth)
};

let model_service = {
let authz = Authz::build(credentials.clone()).await?;
ModelServiceClient::with_interceptor(channel.clone(), authz)
let auth = Authentication::build(credentials.clone()).await?;
ModelServiceClient::with_interceptor(channel.clone(), auth)
};

let text_service = {
let authz = Authz::build(credentials).await?;
TextServiceClient::with_interceptor(channel, authz)
let auth = Authentication::build(credentials).await?;
TextServiceClient::with_interceptor(channel, auth)
};

Ok(Self {
Expand All @@ -132,195 +149,8 @@ impl LanguageClient {
}
}

/// API Key authorization
#[derive(Clone)]
pub struct APIKey {
api_key: String,
}

/// Authorization interceptor
#[derive(Clone)]
pub enum Authz {
/// API Key authorization
APIKey(APIKey),
}

impl Authz {
async fn build(credentials: Credentials) -> Result<Authz, Error> {
match credentials {
Credentials::ApiKey(api_key) => {
let authz = APIKey { api_key };

Ok(Authz::APIKey(authz))
}
}
}
}

impl tonic::service::Interceptor for Authz {
fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, tonic::Status> {
match self {
Authz::APIKey(api_key_auth) => {
let api_key = api_key_auth.api_key.clone();
let api_key = FromStr::from_str(&api_key).unwrap();
req.metadata_mut().insert("x-goog-api-key", api_key);
Ok(req)
}
}
}
}

#[cfg(test)]
mod tests {

use crate::common::test_client;
use crate::google::ai::generativelanguage::v1beta2::{
CountMessageTokensRequest, EmbedTextRequest, GenerateMessageRequest, GenerateTextRequest,
ListModelsRequest, Message, MessagePrompt, TextPrompt,
};

#[tokio::test]
async fn it_list_models() {
let mut client = test_client().await;

let req = ListModelsRequest {
page_size: 3,
page_token: "".to_string(),
};

dbg!(&req);

let resp = client.model_service.list_models(req).await;

dbg!(&resp);

assert!(resp.is_ok());

let resp = resp.unwrap();
for m in resp.get_ref().models.iter() {
println!("Model: {}: {}", m.name, m.description);
}

assert!(!resp.get_ref().models.is_empty());
}

#[tokio::test]
async fn it_count_tokens() {
let mut client = test_client().await;

let req = CountMessageTokensRequest {
model: "models/chat-bison-001".to_string(),
prompt: Some(MessagePrompt {
context: "Hello".to_string(),
examples: vec![],
messages: vec![Message {
author: "1".to_string(),
content: "How are you today?".to_string(),
citation_metadata: None,
}],
}),
};

dbg!(&req);

let resp = client.discuss_service.count_message_tokens(req).await;

dbg!(&resp);

assert!(resp.is_ok());

let resp = resp.unwrap();
assert!(resp.get_ref().token_count > 0);
}

#[tokio::test]
async fn it_generates_discussions() {
let mut client = test_client().await;

let req = GenerateMessageRequest {
model: "models/chat-bison-001".to_string(),
prompt: Some(MessagePrompt {
context: "Hello".to_string(),
examples: vec![],
messages: vec![Message {
author: "1".to_string(),
content: "How are you today?".to_string(),
citation_metadata: None,
}],
}),
temperature: None,
candidate_count: None,
top_p: None,
top_k: None,
};

dbg!(&req);

let resp = client.discuss_service.generate_message(req).await;

dbg!(&resp);

assert!(resp.is_ok());

let resp = resp.unwrap();

dbg!(resp);
}

#[tokio::test]
async fn it_generates_text() {
let mut client = test_client().await;

let req = GenerateTextRequest {
model: "models/text-bison-001".to_string(),
prompt: Some(TextPrompt {
text: "Hello my dear".to_string(),
}),
temperature: None,
candidate_count: None,
max_output_tokens: None,
top_p: None,
top_k: None,
safety_settings: vec![],
stop_sequences: vec![],
};

dbg!(&req);

let resp = client.text_service.generate_text(req).await;

dbg!(&resp);

assert!(resp.is_ok());

let resp = resp.unwrap();

dbg!(resp);
}

#[tokio::test]
async fn it_embeds_text() {
let mut client = test_client().await;

let req = EmbedTextRequest {
model: "models/embedding-gecko-001".to_string(),

text: "Je pense donc...".to_string(),
};

dbg!(&req);

let resp = client.text_service.embed_text(req).await;

dbg!(&resp);

assert!(resp.is_ok());

let resp = resp.unwrap();

dbg!(resp);
}
}
mod test;

#[cfg(test)]
mod common {
Expand Down
Loading

0 comments on commit e6c2e6b

Please sign in to comment.