diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index e13f13849ade..0e0403602872 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -123,11 +123,11 @@ async fn main() { .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)), #[cfg(feature = "ee")] Commands::WorkerCompletion(args) => { - worker::main(tabby_webserver::api::WorkerKind::Completion, args).await + worker::main(tabby_webserver::public::WorkerKind::Completion, args).await } #[cfg(feature = "ee")] Commands::WorkerChat(args) => { - worker::main(tabby_webserver::api::WorkerKind::Chat, args).await + worker::main(tabby_webserver::public::WorkerKind::Chat, args).await } } diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs index ad095c8a682d..911158b7e9e4 100644 --- a/crates/tabby/src/worker.rs +++ b/crates/tabby/src/worker.rs @@ -2,7 +2,7 @@ use std::{env::consts::ARCH, net::IpAddr, sync::Arc}; use axum::{routing, Router}; use clap::Args; -use tabby_webserver::api::{HubClient, WorkerKind}; +use tabby_webserver::public::{HubClient, RegisterWorkerRequest, WorkerKind}; use tracing::info; use crate::{ @@ -94,17 +94,19 @@ impl WorkerContext { let cuda_devices = read_cuda_devices().unwrap_or_default(); Self { - client: tabby_webserver::api::create_client( + client: tabby_webserver::public::create_client( &args.url, &args.token, - kind, - args.port as i32, - args.model.to_owned(), - args.device.to_string(), - ARCH.to_string(), - cpu_info, - cpu_count as i32, - cuda_devices, + RegisterWorkerRequest { + kind, + port: args.port as i32, + name: args.model.to_owned(), + device: args.device.to_string(), + arch: ARCH.to_string(), + cpu_info, + cpu_count: cpu_count as i32, + cuda_devices, + }, ) .await, } diff --git a/ee/tabby-webserver/examples/update-schema.rs b/ee/tabby-webserver/examples/update-schema.rs index 40c340d15389..bf0152e4fdfb 100644 --- a/ee/tabby-webserver/examples/update-schema.rs +++ b/ee/tabby-webserver/examples/update-schema.rs @@ -1,6 +1,6 @@ use std::fs::write; -use tabby_webserver::create_schema; +use tabby_webserver::public::create_schema; fn main() { let schema = create_schema(); diff --git a/ee/tabby-webserver/src/handler.rs b/ee/tabby-webserver/src/handler.rs new file mode 100644 index 000000000000..24313ccee2eb --- /dev/null +++ b/ee/tabby-webserver/src/handler.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use axum::{ + extract::State, + http::Request, + middleware::{from_fn_with_state, Next}, + routing, Extension, Router, +}; +use hyper::Body; +use juniper_axum::{graphiql, graphql, playground}; +use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; + +use crate::{ + hub, repositories, + schema::{create_schema, Schema, ServiceLocator}, + service::create_service_locator, + ui, +}; + +pub async fn attach_webserver( + api: Router, + ui: Router, + logger: Arc, + code: Arc, +) -> (Router, Router) { + let ctx = create_service_locator(logger, code).await; + let schema = Arc::new(create_schema()); + + let api = api + .layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer)) + .route( + "/graphql", + routing::post(graphql::, Arc>).with_state(ctx.clone()), + ) + .route("/graphql", routing::get(playground("/graphql", None))) + .layer(Extension(schema)) + .route( + "/hub", + routing::get(hub::ws_handler).with_state(ctx.clone()), + ) + .nest("/repositories", repositories::routes(ctx.clone())); + + let ui = ui + .route("/graphiql", routing::get(graphiql("/graphql", None))) + .fallback(ui::handler); + + (api, ui) +} + +async fn distributed_tabby_layer( + State(ws): State>, + request: Request, + next: Next, +) -> axum::response::Response { + ws.worker().dispatch_request(request, next).await +} diff --git a/ee/tabby-webserver/src/hub/api.rs b/ee/tabby-webserver/src/hub/api.rs index 292b82907f07..61de89b2eecc 100644 --- a/ee/tabby-webserver/src/hub/api.rs +++ b/ee/tabby-webserver/src/hub/api.rs @@ -9,7 +9,7 @@ use tabby_common::api::{ use tokio_tungstenite::connect_async; use super::websocket::WebSocketTransport; -pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind}; +pub use crate::schema::worker::WorkerKind; #[tarpc::service] pub trait Hub { @@ -29,18 +29,7 @@ pub fn tracing_context() -> tarpc::context::Context { tarpc::context::current() } -pub async fn create_client( - addr: &str, - token: &str, - kind: WorkerKind, - port: i32, - name: String, - device: String, - arch: String, - cpu_info: String, - cpu_count: i32, - cuda_devices: Vec, -) -> HubClient { +pub async fn create_client(addr: &str, token: &str, request: RegisterWorkerRequest) -> HubClient { let request = Request::builder() .uri(format!("ws://{}/hub", addr)) .header("Host", addr) @@ -52,17 +41,7 @@ pub async fn create_client( .header("Content-Type", "application/json") .header( ®ISTER_WORKER_HEADER, - serde_json::to_string(&RegisterWorkerRequest { - kind, - port, - name, - device, - arch, - cpu_info, - cpu_count, - cuda_devices, - }) - .unwrap(), + serde_json::to_string(&request).unwrap(), ) .body(()) .unwrap(); @@ -121,19 +100,18 @@ impl CodeSearch for HubClient { } #[derive(Serialize, Deserialize)] -pub(crate) struct RegisterWorkerRequest { - pub(crate) kind: WorkerKind, - pub(crate) port: i32, - pub(crate) name: String, - pub(crate) device: String, - pub(crate) arch: String, - pub(crate) cpu_info: String, - pub(crate) cpu_count: i32, - pub(crate) cuda_devices: Vec, +pub struct RegisterWorkerRequest { + pub kind: WorkerKind, + pub port: i32, + pub name: String, + pub device: String, + pub arch: String, + pub cpu_info: String, + pub cpu_count: i32, + pub cuda_devices: Vec, } -pub(crate) static REGISTER_WORKER_HEADER: HeaderName = - HeaderName::from_static("x-tabby-register-worker"); +pub static REGISTER_WORKER_HEADER: HeaderName = HeaderName::from_static("x-tabby-register-worker"); impl Header for RegisterWorkerRequest { fn name() -> &'static axum::http::HeaderName { diff --git a/ee/tabby-webserver/src/hub/mod.rs b/ee/tabby-webserver/src/hub/mod.rs index 61208e6809c9..de459527c856 100644 --- a/ee/tabby-webserver/src/hub/mod.rs +++ b/ee/tabby-webserver/src/hub/mod.rs @@ -3,26 +3,20 @@ mod websocket; use std::{net::SocketAddr, sync::Arc}; +use api::{Hub, RegisterWorkerRequest}; use axum::{ extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, - headers::Header, response::IntoResponse, TypedHeader, }; use hyper::{Body, StatusCode}; use juniper_axum::extract::AuthBearer; -use tabby_common::api::{ - code::{CodeSearch, SearchResponse}, - event::RawEventLogger, -}; +use tabby_common::api::code::SearchResponse; use tarpc::server::{BaseChannel, Channel}; use tracing::warn; +use websocket::WebSocketTransport; -use self::websocket::WebSocketTransport; -use crate::{ - api::{Hub, RegisterWorkerRequest}, - schema::{worker::Worker, ServiceLocator}, -}; +use crate::schema::{worker::Worker, ServiceLocator}; pub(crate) async fn ws_handler( ws: WebSocketUpgrade, @@ -74,13 +68,13 @@ async fn handle_socket(state: Arc, socket: WebSocket, worker tokio::spawn(server.execute(imp.serve())).await.unwrap() } -pub struct HubImpl { +struct HubImpl { ctx: Arc, worker_addr: String, } impl HubImpl { - pub fn new(ctx: Arc, worker_addr: String) -> Self { + fn new(ctx: Arc, worker_addr: String) -> Self { Self { ctx, worker_addr } } } diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index b824065073cd..dc2e935bc35c 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -1,14 +1,19 @@ -// used by tabby workers. -pub use hub::api; -// used by examples/update-schema.rs -pub use schema::create_schema; - +mod handler; mod hub; mod repositories; mod schema; mod service; mod ui; +pub mod public { + pub use super::{ + handler::attach_webserver, + /* used by tabby workers (consumer of /hub api) */ + hub::api::{create_client, HubClient, RegisterWorkerRequest, WorkerKind}, + /* used by examples/update-schema.rs */ schema::create_schema, + }; +} + use std::sync::Arc; use axum::{ @@ -30,7 +35,7 @@ pub async fn attach_webserver( code: Arc, ) -> (Router, Router) { let ctx = create_service_locator(logger, code).await; - let schema = Arc::new(create_schema()); + let schema = Arc::new(schema::create_schema()); let api = api .layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))