Skip to content

Commit

Permalink
feat: add address arg for duckdb-server-rust (uwdata#649)
Browse files Browse the repository at this point in the history
* feat: add address arg for duckdb-server-rust

* parse --address directly to IpAddr type

* define default values for app as constants

* format
  • Loading branch information
kwonoh authored Jan 9, 2025
1 parent 79e0218 commit 3ba3a0e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
10 changes: 7 additions & 3 deletions packages/duckdb-server-rust/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ async fn handle_get(
}
}

pub const DEFAULT_DB_PATH: &str = ":memory:";
pub const DEFAULT_CONNECTION_POOL_SIZE: u32 = 10;
pub const DEFAULT_CACHE_SIZE: usize = 1000;

#[axum::debug_handler]
async fn handle_post(
State(state): State<Arc<AppState>>,
Expand All @@ -49,10 +53,10 @@ pub fn app(
) -> Result<Router> {
// Database and state setup
let db = ConnectionPool::new(
dp_path.unwrap_or(":memory:"),
connection_pool_size.unwrap_or(10),
dp_path.unwrap_or(DEFAULT_DB_PATH),
connection_pool_size.unwrap_or(DEFAULT_CONNECTION_POOL_SIZE),
)?;
let cache = lru::LruCache::new(cache_size.unwrap_or(1000).try_into()?);
let cache = lru::LruCache::new(cache_size.unwrap_or(DEFAULT_CACHE_SIZE).try_into()?);

let state = Arc::new(AppState {
db: Box::new(db),
Expand Down
18 changes: 13 additions & 5 deletions packages/duckdb-server-rust/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ use axum_server::tls_rustls::RustlsConfig;
use clap::Parser;
use listenfd::ListenFd;
use std::net::TcpListener;
use std::{net::Ipv4Addr, net::SocketAddr, path::PathBuf};
use std::{net::IpAddr, net::Ipv4Addr, net::SocketAddr, path::PathBuf};
use tokio::net;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use crate::app::DEFAULT_CACHE_SIZE;
use crate::app::DEFAULT_CONNECTION_POOL_SIZE;
use crate::app::DEFAULT_DB_PATH;

mod app;
mod bundle;
mod cache;
Expand All @@ -19,19 +23,23 @@ mod websocket;
#[command(version, about, long_about = None)]
struct Args {
/// Path of database file (e.g., "database.db". ":memory:" for in-memory database)
#[arg(default_value = ":memory:")]
#[arg(default_value = DEFAULT_DB_PATH)]
database: String,

/// HTTP Address
#[arg(short, long, default_value_t = Ipv4Addr::LOCALHOST.into())]
address: IpAddr,

/// HTTP Port
#[arg(short, long, default_value_t = 3000)]
port: u16,

/// Max connection pool size
#[arg(long, default_value_t = 10)]
#[arg(long, default_value_t = DEFAULT_CONNECTION_POOL_SIZE)]
connection_pool_size: u32,

/// Max number of cache entries
#[arg(long, default_value_t = 1000)]
#[arg(long, default_value_t = DEFAULT_CACHE_SIZE)]
cache_size: usize,
}

Expand Down Expand Up @@ -71,7 +79,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

// Listenfd setup
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), args.port);
let addr = SocketAddr::new(args.address, args.port);
let mut listenfd = ListenFd::from_env();
let listener = match listenfd.take_tcp_listener(0)? {
// if we are given a tcp listener on listen fd 0, we use that one
Expand Down

0 comments on commit 3ba3a0e

Please sign in to comment.