Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No connection reuse in SQLite Factor #2709

Merged
merged 2 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions crates/factor-sqlite/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use spin_factors::{anyhow, SelfInstanceBuilder};
use spin_world::v1::sqlite as v1;
use spin_world::v2::sqlite as v2;

use crate::{Connection, ConnectionPool};
use crate::{Connection, ConnectionCreator};

pub struct InstanceState {
allowed_databases: Arc<HashSet<String>>,
connections: table::Table<Arc<dyn Connection>>,
get_pool: ConnectionPoolGetter,
connections: table::Table<Box<dyn Connection>>,
get_connection_creator: ConnectionCreatorGetter,
}

impl InstanceState {
Expand All @@ -22,25 +22,29 @@ impl InstanceState {
}
}

/// A function that takes a database label and returns a connection pool, if one exists.
pub type ConnectionPoolGetter = Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionPool>> + Send + Sync>;
/// A function that takes a database label and returns a connection creator, if one exists.
pub type ConnectionCreatorGetter =
Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionCreator>> + Send + Sync>;

impl InstanceState {
/// Create a new `InstanceState`
///
/// Takes the list of allowed databases, and a function for getting a connection pool given a database label.
pub fn new(allowed_databases: Arc<HashSet<String>>, get_pool: ConnectionPoolGetter) -> Self {
/// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
pub fn new(
allowed_databases: Arc<HashSet<String>>,
get_connection_creator: ConnectionCreatorGetter,
) -> Self {
Self {
allowed_databases,
connections: table::Table::new(256),
get_pool,
get_connection_creator,
}
}

fn get_connection(
&self,
connection: Resource<v2::Connection>,
) -> Result<&Arc<dyn Connection>, v2::Error> {
) -> Result<&Box<dyn Connection>, v2::Error> {
self.connections
.get(connection.rep())
.ok_or(v2::Error::InvalidConnection)
Expand All @@ -61,9 +65,9 @@ impl v2::HostConnection for InstanceState {
if !self.allowed_databases.contains(&database) {
return Err(v2::Error::AccessDenied);
}
(self.get_pool)(&database)
(self.get_connection_creator)(&database)
.ok_or(v2::Error::NoSuchDatabase)?
.get_connection()
.create_connection()
.await
.and_then(|conn| {
self.connections
Expand Down
62 changes: 28 additions & 34 deletions crates/factor-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ impl Factor for SqliteFactor {
&self,
mut ctx: spin_factors::ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
let connection_pools = ctx
let connection_creators = ctx
.take_runtime_config()
.map(|r| r.pools)
.map(|r| r.connection_creators)
.unwrap_or_default();

let allowed_databases = ctx
Expand All @@ -68,20 +68,20 @@ impl Factor for SqliteFactor {
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;
let resolver = self.default_label_resolver.clone();
let get_connection_pool: host::ConnectionPoolGetter = Arc::new(move |label| {
connection_pools
let get_connection_creator: host::ConnectionCreatorGetter = Arc::new(move |label| {
connection_creators
.get(label)
.cloned()
.or_else(|| resolver.default(label))
});

ensure_allowed_databases_are_configured(&allowed_databases, |label| {
get_connection_pool(label).is_some()
get_connection_creator(label).is_some()
})?;

Ok(AppState {
allowed_databases,
get_connection_pool,
get_connection_creator,
})
}

Expand All @@ -96,8 +96,11 @@ impl Factor for SqliteFactor {
.get(ctx.app_component().id())
.cloned()
.unwrap_or_default();
let get_connection_pool = ctx.app_state().get_connection_pool.clone();
Ok(InstanceState::new(allowed_databases, get_connection_pool))
let get_connection_creator = ctx.app_state().get_connection_creator.clone();
Ok(InstanceState::new(
allowed_databases,
get_connection_creator,
))
}
}

Expand Down Expand Up @@ -136,46 +139,37 @@ fn ensure_allowed_databases_are_configured(

pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("databases");

/// Resolves a label to a default connection pool.
/// Resolves a label to a default connection creator.
pub trait DefaultLabelResolver: Send + Sync {
/// If there is no runtime configuration for a given database label, return a default connection pool.
/// If there is no runtime configuration for a given database label, return a default connection creator.
///
/// If `Option::None` is returned, the database is not allowed.
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionPool>>;
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionCreator>>;
}

pub struct AppState {
/// A map from component id to a set of allowed database labels.
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
/// A function for mapping from database name to a connection pool
get_connection_pool: host::ConnectionPoolGetter,
/// A function for mapping from database name to a connection creator.
get_connection_creator: host::ConnectionCreatorGetter,
}

/// A pool of connections for a particular SQLite database
/// A creator of a connections for a particular SQLite database.
#[async_trait]
pub trait ConnectionPool: Send + Sync {
/// Get a `Connection` from the pool
async fn get_connection(&self) -> Result<Arc<dyn Connection + 'static>, v2::Error>;
}

/// A simple [`ConnectionPool`] that always creates a new connection.
pub struct SimpleConnectionPool(
Box<dyn Fn() -> anyhow::Result<Arc<dyn Connection + 'static>> + Send + Sync>,
);

impl SimpleConnectionPool {
/// Create a new `SimpleConnectionPool` with the given connection factory.
pub fn new(
factory: impl Fn() -> anyhow::Result<Arc<dyn Connection + 'static>> + Send + Sync + 'static,
) -> Self {
Self(Box::new(factory))
}
pub trait ConnectionCreator: Send + Sync {
/// Get a *new* [`Connection`]
///
/// The connection should be a new connection, not a reused one.
async fn create_connection(&self) -> Result<Box<dyn Connection + 'static>, v2::Error>;
}

#[async_trait::async_trait]
impl ConnectionPool for SimpleConnectionPool {
async fn get_connection(&self) -> Result<Arc<dyn Connection + 'static>, v2::Error> {
(self.0)().map_err(|_| v2::Error::InvalidConnection)
impl<F> ConnectionCreator for F
where
F: Fn() -> anyhow::Result<Box<dyn Connection + 'static>> + Send + Sync + 'static,
{
async fn create_connection(&self) -> Result<Box<dyn Connection + 'static>, v2::Error> {
(self)().map_err(|_| v2::Error::InvalidConnection)
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/factor-sqlite/src/runtime_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ pub mod spin;

use std::{collections::HashMap, sync::Arc};

use crate::ConnectionPool;
use crate::ConnectionCreator;

/// A runtime configuration for SQLite databases.
///
/// Maps database labels to connection pools.
/// Maps database labels to connection creators.
pub struct RuntimeConfig {
pub pools: HashMap<String, Arc<dyn ConnectionPool>>,
pub connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}
51 changes: 28 additions & 23 deletions crates/factor-sqlite/src/runtime_config/spin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use spin_factors::{
use spin_world::v2::sqlite as v2;
use tokio::sync::OnceCell;

use crate::{Connection, ConnectionPool, DefaultLabelResolver, SimpleConnectionPool};
use crate::{Connection, ConnectionCreator, DefaultLabelResolver};

/// Spin's default handling of the runtime configuration for SQLite databases.
///
Expand Down Expand Up @@ -66,28 +66,34 @@ impl SpinSqliteRuntimeConfig {
return Ok(None);
};
let config: std::collections::HashMap<String, RuntimeConfig> = table.clone().try_into()?;
let pools = config
let connection_creators = config
.into_iter()
.map(|(k, v)| Ok((k, self.get_pool(v)?)))
.map(|(k, v)| Ok((k, self.get_connection_creator(v)?)))
.collect::<anyhow::Result<_>>()?;
Ok(Some(super::RuntimeConfig { pools }))
Ok(Some(super::RuntimeConfig {
connection_creators,
}))
}

/// Get a connection pool for a given runtime configuration.
pub fn get_pool(&self, config: RuntimeConfig) -> anyhow::Result<Arc<dyn ConnectionPool>> {
/// Get a connection creator for a given runtime configuration.
pub fn get_connection_creator(
&self,
config: RuntimeConfig,
) -> anyhow::Result<Arc<dyn ConnectionCreator>> {
let database_kind = config.type_.as_str();
let pool = match database_kind {
match database_kind {
"spin" => {
let config: LocalDatabase = config.config.try_into()?;
config.pool(&self.local_database_dir)?
Ok(Arc::new(
config.connection_creator(&self.local_database_dir)?,
))
}
"libsql" => {
let config: LibSqlDatabase = config.config.try_into()?;
config.pool()?
Ok(Arc::new(config.connection_creator()?))
}
_ => anyhow::bail!("Unknown database kind: {database_kind}"),
};
Ok(Arc::new(pool))
}
}
}

Expand All @@ -100,7 +106,7 @@ pub struct RuntimeConfig {
}

impl DefaultLabelResolver for SpinSqliteRuntimeConfig {
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionPool>> {
fn default(&self, label: &str) -> Option<Arc<dyn ConnectionCreator>> {
// Only default the database labeled "default".
if label != "default" {
return None;
Expand All @@ -110,10 +116,9 @@ impl DefaultLabelResolver for SpinSqliteRuntimeConfig {
let factory = move || {
let location = spin_sqlite_inproc::InProcDatabaseLocation::Path(path.clone());
let connection = spin_sqlite_inproc::InProcConnection::new(location)?;
Ok(Arc::new(connection) as _)
Ok(Box::new(connection) as _)
};
let pool = SimpleConnectionPool::new(factory);
Some(Arc::new(pool))
Some(Arc::new(factory))
}
}

Expand Down Expand Up @@ -196,10 +201,10 @@ pub struct LocalDatabase {
}

impl LocalDatabase {
/// Create a new connection pool for a local database.
/// Get a new connection creator for a local database.
///
/// `base_dir` is the base directory path from which `path` is resolved if it is a relative path.
fn pool(self, base_dir: &Path) -> anyhow::Result<SimpleConnectionPool> {
fn connection_creator(self, base_dir: &Path) -> anyhow::Result<impl ConnectionCreator> {
let location = match self.path {
Some(path) => {
let path = resolve_relative_path(&path, base_dir);
Expand All @@ -213,9 +218,9 @@ impl LocalDatabase {
};
let factory = move || {
let connection = spin_sqlite_inproc::InProcConnection::new(location.clone())?;
Ok(Arc::new(connection) as _)
Ok(Box::new(connection) as _)
};
Ok(SimpleConnectionPool::new(factory))
Ok(factory)
}
}

Expand All @@ -238,8 +243,8 @@ pub struct LibSqlDatabase {
}

impl LibSqlDatabase {
/// Create a new connection pool for a libSQL database.
fn pool(self) -> anyhow::Result<SimpleConnectionPool> {
/// Get a new connection creator for a libSQL database.
fn connection_creator(self) -> anyhow::Result<impl ConnectionCreator> {
let url = check_url(&self.url)
.with_context(|| {
format!(
Expand All @@ -250,9 +255,9 @@ impl LibSqlDatabase {
.to_owned();
let factory = move || {
let connection = LibSqlConnection::new(url.clone(), self.token.clone());
Ok(Arc::new(connection) as _)
Ok(Box::new(connection) as _)
};
Ok(SimpleConnectionPool::new(factory))
Ok(factory)
}
}

Expand Down
16 changes: 8 additions & 8 deletions crates/factor-sqlite/tests/factor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl TryFrom<TomlRuntimeSource<'_>> for TestFactorsRuntimeConfig {
}
}

/// Will return an `InvalidConnectionPool` for the supplied default database.
/// Will return an `InvalidConnectionCreator` for the supplied default database.
struct DefaultLabelResolver {
default: Option<String>,
}
Expand All @@ -130,22 +130,22 @@ impl DefaultLabelResolver {
}

impl factor_sqlite::DefaultLabelResolver for DefaultLabelResolver {
fn default(&self, label: &str) -> Option<Arc<dyn factor_sqlite::ConnectionPool>> {
fn default(&self, label: &str) -> Option<Arc<dyn factor_sqlite::ConnectionCreator>> {
let Some(default) = &self.default else {
return None;
};
(default == label).then_some(Arc::new(InvalidConnectionPool))
(default == label).then_some(Arc::new(InvalidConnectionCreator))
}
}

/// A connection pool that always returns an error.
struct InvalidConnectionPool;
/// A connection creator that always returns an error.
struct InvalidConnectionCreator;

#[async_trait::async_trait]
impl factor_sqlite::ConnectionPool for InvalidConnectionPool {
async fn get_connection(
impl factor_sqlite::ConnectionCreator for InvalidConnectionCreator {
async fn create_connection(
&self,
) -> Result<Arc<dyn factor_sqlite::Connection + 'static>, spin_world::v2::sqlite::Error> {
) -> Result<Box<dyn factor_sqlite::Connection + 'static>, spin_world::v2::sqlite::Error> {
Err(spin_world::v2::sqlite::Error::InvalidConnection)
}
}
Loading