Skip to content

Commit

Permalink
Rework how connection creation is done
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>
  • Loading branch information
rylev committed Sep 17, 2024
1 parent e236a7f commit 750ad04
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 98 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/factor-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ tracing = { workspace = true }

[dev-dependencies]
spin-factors-test = { path = "../factors-test" }
spin-sqlite = { path = "../sqlite" }
tokio = { version = "1", features = ["macros", "rt"] }

[lints]
Expand Down
30 changes: 15 additions & 15 deletions crates/factor-sqlite/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use async_trait::async_trait;
Expand All @@ -14,35 +14,28 @@ use crate::{Connection, ConnectionCreator};

pub struct InstanceState {
allowed_databases: Arc<HashSet<String>>,
/// A resource table of connections.
connections: table::Table<Box<dyn Connection>>,
get_connection_creator: ConnectionCreatorGetter,
/// A map from database label to connection creators.
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}

impl InstanceState {
pub fn allowed_databases(&self) -> &HashSet<String> {
&self.allowed_databases
}
}

/// A function that takes a database label and returns a connection creator, if one exists.
pub type ConnectionCreatorGetter =
Arc<dyn Fn(&str) -> Option<Arc<dyn ConnectionCreator>> + Send + Sync>;

impl InstanceState {
/// Create a new `InstanceState`
///
/// Takes the list of allowed databases, and a function for getting a connection creator given a database label.
pub fn new(
allowed_databases: Arc<HashSet<String>>,
get_connection_creator: ConnectionCreatorGetter,
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
) -> Self {
Self {
allowed_databases,
connections: table::Table::new(256),
get_connection_creator,
connection_creators,
}
}

/// Get a connection for a given database label.
fn get_connection(
&self,
connection: Resource<v2::Connection>,
Expand All @@ -52,6 +45,11 @@ impl InstanceState {
.map(|conn| conn.as_ref())
.ok_or(v2::Error::InvalidConnection)
}

/// Get the set of allowed databases.
pub fn allowed_databases(&self) -> &HashSet<String> {
&self.allowed_databases
}
}

impl SelfInstanceBuilder for InstanceState {}
Expand All @@ -69,7 +67,9 @@ impl v2::HostConnection for InstanceState {
if !self.allowed_databases.contains(&database) {
return Err(v2::Error::AccessDenied);
}
let conn = (self.get_connection_creator)(&database)
let conn = self
.connection_creators
.get(&database)
.ok_or(v2::Error::NoSuchDatabase)?
.create_connection(&database)
.await?;
Expand Down
21 changes: 10 additions & 11 deletions crates/factor-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,12 @@ impl Factor for SqliteFactor {
))
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;
let get_connection_creator: host::ConnectionCreatorGetter =
Arc::new(move |label| connection_creators.get(label).cloned());

ensure_allowed_databases_are_configured(&allowed_databases, |label| {
get_connection_creator(label).is_some()
connection_creators.contains_key(label)
})?;

Ok(AppState::new(allowed_databases, get_connection_creator))
Ok(AppState::new(allowed_databases, connection_creators))
}

fn prepare<T: spin_factors::RuntimeFactors>(
Expand All @@ -84,10 +82,9 @@ impl Factor for SqliteFactor {
.get(ctx.app_component().id())
.cloned()
.unwrap_or_default();
let get_connection_creator = ctx.app_state().get_connection_creator.clone();
Ok(InstanceState::new(
allowed_databases,
get_connection_creator,
ctx.app_state().connection_creators.clone(),
))
}
}
Expand Down Expand Up @@ -132,19 +129,19 @@ pub const ALLOWED_DATABASES_KEY: MetadataKey<Vec<String>> = MetadataKey::new("da
pub struct AppState {
/// A map from component id to a set of allowed database labels.
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
/// A function for mapping from database name to a connection creator.
get_connection_creator: host::ConnectionCreatorGetter,
/// A mapping from database label to a connection creator.
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
}

impl AppState {
/// Create a new `AppState`
pub fn new(
allowed_databases: HashMap<String, Arc<HashSet<String>>>,
get_connection_creator: host::ConnectionCreatorGetter,
connection_creators: HashMap<String, Arc<dyn ConnectionCreator>>,
) -> Self {
Self {
allowed_databases,
get_connection_creator,
connection_creators,
}
}

Expand All @@ -155,7 +152,9 @@ impl AppState {
&self,
label: &str,
) -> Option<Result<Box<dyn Connection>, v2::Error>> {
let connection = (self.get_connection_creator)(label)?
let connection = self
.connection_creators
.get(label)?
.create_connection(label)
.await;
Some(connection)
Expand Down
138 changes: 68 additions & 70 deletions crates/factor-sqlite/tests/factor_test.rs
Original file line number Diff line number Diff line change
@@ -1,132 +1,130 @@
use std::{collections::HashSet, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};

use spin_factor_sqlite::SqliteFactor;
use spin_factor_sqlite::{RuntimeConfig, SqliteFactor};
use spin_factors::{
anyhow::{self, bail, Context},
runtime_config::toml::TomlKeyTracker,
Factor, FactorRuntimeConfigSource, RuntimeConfigSourceFinalizer, RuntimeFactors,
anyhow::{self, bail, Context as _},
RuntimeFactors,
};
use spin_factors_test::{toml, TestEnvironment};
use spin_sqlite::RuntimeConfigResolver;
use spin_world::async_trait;
use spin_world::{async_trait, v2::sqlite as v2};
use v2::HostConnection as _;

#[derive(RuntimeFactors)]
struct TestFactors {
sqlite: SqliteFactor,
}

#[tokio::test]
async fn sqlite_works() -> anyhow::Result<()> {
async fn errors_when_non_configured_database_used() -> anyhow::Result<()> {
let factors = TestFactors {
sqlite: SqliteFactor::new(),
};
let env = TestEnvironment::new(factors).extend_manifest(toml! {
[component.test-component]
source = "does-not-exist.wasm"
sqlite_databases = ["default"]
sqlite_databases = ["foo"]
});
let state = env.build_instance_state().await?;
let Err(err) = env.build_instance_state().await else {
bail!("Expected build_instance_state to error but it did not");
};

assert_eq!(
state.sqlite.allowed_databases(),
&["default".into()].into_iter().collect::<HashSet<_>>()
);
assert!(err
.to_string()
.contains("One or more components use SQLite databases which are not defined."));

Ok(())
}

#[tokio::test]
async fn errors_when_non_configured_database_used() -> anyhow::Result<()> {
async fn errors_when_database_not_allowed() -> anyhow::Result<()> {
let factors = TestFactors {
sqlite: SqliteFactor::new(),
};
let env = TestEnvironment::new(factors).extend_manifest(toml! {
[component.test-component]
source = "does-not-exist.wasm"
sqlite_databases = ["foo"]
sqlite_databases = []
});
let Err(err) = env.build_instance_state().await else {
bail!("Expected build_instance_state to error but it did not");
};
let mut state = env
.build_instance_state()
.await
.context("build_instance_state failed")?;

assert!(err
.to_string()
.contains("One or more components use SQLite databases which are not defined."));
assert!(matches!(
state.sqlite.open("foo".into()).await,
Err(spin_world::v2::sqlite::Error::AccessDenied)
));

Ok(())
}

#[tokio::test]
async fn no_error_when_database_is_configured() -> anyhow::Result<()> {
async fn it_works_when_database_is_configured() -> anyhow::Result<()> {
let factors = TestFactors {
sqlite: SqliteFactor::new(),
};
let runtime_config = toml! {
[sqlite_database.foo]
type = "spin"
let mut connection_creators = HashMap::new();
connection_creators.insert("foo".to_owned(), Arc::new(MockConnectionCreator) as _);
let runtime_config = TestFactorsRuntimeConfig {
sqlite: Some(RuntimeConfig {
connection_creators,
}),
};
let sqlite_config = RuntimeConfigResolver::new(None, "/".into());
let env = TestEnvironment::new(factors)
.extend_manifest(toml! {
[component.test-component]
source = "does-not-exist.wasm"
sqlite_databases = ["foo"]
})
.runtime_config(TomlRuntimeSource::new(&runtime_config, sqlite_config))?;
env.build_instance_state()
.runtime_config(runtime_config)?;

let mut state = env
.build_instance_state()
.await
.context("build_instance_state failed")?;
Ok(())
}

struct TomlRuntimeSource<'a> {
table: TomlKeyTracker<'a>,
runtime_config_resolver: RuntimeConfigResolver,
}

impl<'a> TomlRuntimeSource<'a> {
fn new(table: &'a toml::Table, runtime_config_resolver: RuntimeConfigResolver) -> Self {
Self {
table: TomlKeyTracker::new(table),
runtime_config_resolver,
}
}
}

impl FactorRuntimeConfigSource<SqliteFactor> for TomlRuntimeSource<'_> {
fn get_runtime_config(
&mut self,
) -> anyhow::Result<Option<<SqliteFactor as Factor>::RuntimeConfig>> {
self.runtime_config_resolver.resolve_from_toml(&self.table)
}
}
assert_eq!(
state.sqlite.allowed_databases(),
&["foo".into()].into_iter().collect::<HashSet<_>>()
);

impl RuntimeConfigSourceFinalizer for TomlRuntimeSource<'_> {
fn finalize(&mut self) -> anyhow::Result<()> {
self.table.validate_all_keys_used()?;
Ok(())
}
assert!(state.sqlite.open("foo".into()).await.is_ok());
Ok(())
}

impl TryFrom<TomlRuntimeSource<'_>> for TestFactorsRuntimeConfig {
type Error = anyhow::Error;
/// A connection creator that returns a mock connection.
struct MockConnectionCreator;

fn try_from(value: TomlRuntimeSource<'_>) -> Result<Self, Self::Error> {
Self::from_source(value)
#[async_trait]
impl spin_factor_sqlite::ConnectionCreator for MockConnectionCreator {
async fn create_connection(
&self,
label: &str,
) -> Result<Box<dyn spin_factor_sqlite::Connection + 'static>, v2::Error> {
let _ = label;
Ok(Box::new(MockConnection))
}
}

/// A connection creator that always returns an error.
struct InvalidConnectionCreator;
/// A mock connection that always errors.
struct MockConnection;

#[async_trait]
impl spin_factor_sqlite::ConnectionCreator for InvalidConnectionCreator {
async fn create_connection(
impl spin_factor_sqlite::Connection for MockConnection {
async fn query(
&self,
label: &str,
) -> Result<Box<dyn spin_factor_sqlite::Connection + 'static>, spin_world::v2::sqlite::Error>
{
let _ = label;
Err(spin_world::v2::sqlite::Error::InvalidConnection)
query: &str,
parameters: Vec<v2::Value>,
) -> Result<v2::QueryResult, v2::Error> {
let _ = (query, parameters);
Err(v2::Error::Io("Mock connection".into()))
}

async fn execute_batch(&self, statements: &str) -> anyhow::Result<()> {
let _ = statements;
bail!("Mock connection")
}
}

0 comments on commit 750ad04

Please sign in to comment.