Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Automatically sync the configuration on server startup
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Mar 1, 2024
1 parent 7200f94 commit c0a9d27
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 281 deletions.
305 changes: 25 additions & 280 deletions crates/cli/src/commands/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,75 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;

use anyhow::Context;
use camino::Utf8PathBuf;
use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
use mas_storage::{
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use rand::SeedableRng;
use sqlx::{postgres::PgAdvisoryLock, Acquire};
use tokio::io::AsyncWriteExt;
use tracing::{error, info, info_span, warn};
use tracing::{info, info_span, Instrument};

use crate::util::database_connection_from_config;

fn map_import_action(
config: &mas_config::UpstreamOAuth2ImportAction,
) -> mas_data_model::UpstreamOAuthProviderImportAction {
match config {
mas_config::UpstreamOAuth2ImportAction::Ignore => {
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
}
mas_config::UpstreamOAuth2ImportAction::Suggest => {
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
}
mas_config::UpstreamOAuth2ImportAction::Force => {
mas_data_model::UpstreamOAuthProviderImportAction::Force
}
mas_config::UpstreamOAuth2ImportAction::Require => {
mas_data_model::UpstreamOAuthProviderImportAction::Require
}
}
}

fn map_claims_imports(
config: &mas_config::UpstreamOAuth2ClaimsImports,
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
mas_data_model::UpstreamOAuthProviderClaimsImports {
subject: mas_data_model::UpstreamOAuthProviderSubjectPreference {
template: config.subject.template.clone(),
},
localpart: mas_data_model::UpstreamOAuthProviderImportPreference {
action: map_import_action(&config.localpart.action),
template: config.localpart.template.clone(),
},
displayname: mas_data_model::UpstreamOAuthProviderImportPreference {
action: map_import_action(&config.displayname.action),
template: config.displayname.template.clone(),
},
email: mas_data_model::UpstreamOAuthProviderImportPreference {
action: map_import_action(&config.email.action),
template: config.email.template.clone(),
},
verify_email: match config.email.set_email_verification {
mas_config::UpstreamOAuth2SetEmailVerification::Always => {
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Always
}
mas_config::UpstreamOAuth2SetEmailVerification::Never => {
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Never
}
mas_config::UpstreamOAuth2SetEmailVerification::Import => {
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import
}
},
}
}

#[derive(Parser, Debug)]
pub(super) struct Options {
#[command(subcommand)]
Expand Down Expand Up @@ -169,230 +112,32 @@ impl Options {
}

SC::Sync { prune, dry_run } => {
sync(root, prune, dry_run).await?;
}
}

Ok(())
}
}

#[tracing::instrument(name = "cli.config.sync", skip(root), err(Debug))]
async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Result<()> {
// XXX: we should disallow SeedableRng::from_entropy
let clock = SystemClock::default();

let config: SyncConfig = root.load_config()?;
let encrypter = config.secrets.encrypter();
// Grab a connection to the database
let mut conn = database_connection_from_config(&config.database).await?;
// Start a transaction
let txn = conn.begin().await?;

// Grab a lock within the transaction
tracing::info!("Acquiring config lock");
let lock = PgAdvisoryLock::new("MAS config sync");
let lock = lock.acquire(txn).await?;

// Create a repository from the connection with the lock
let mut repo = PgRepository::from_conn(lock);

tracing::info!(
prune,
dry_run,
"Syncing providers and clients defined in config to database"
);

{
let _span = info_span!("cli.config.sync.providers").entered();
let config_ids = config
.upstream_oauth2
.providers
.iter()
.map(|p| p.id)
.collect::<HashSet<_>>();

let existing = repo.upstream_oauth_provider().all().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for provider in to_delete {
info!(%provider.id, "Deleting provider");

if dry_run {
continue;
}

repo.upstream_oauth_provider().delete(provider).await?;
}
} else {
let len = to_delete.count();
match len {
0 => {},
1 => warn!("A provider in the database is not in the config. Run with `--prune` to delete it."),
n => warn!("{n} providers in the database are not in the config. Run with `--prune` to delete them."),
}
}

for provider in config.upstream_oauth2.providers {
let _span = info_span!("provider", %provider.id).entered();
if existing_ids.contains(&provider.id) {
info!("Updating provider");
} else {
info!("Adding provider");
}

if dry_run {
continue;
}

let encrypted_client_secret = provider
.client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let token_endpoint_auth_method = provider.client_auth_method();
let token_endpoint_signing_alg = provider.client_auth_signing_alg();

let discovery_mode = match provider.discovery_mode {
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
}
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
}
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
}
};

if discovery_mode.is_disabled() {
if provider.authorization_endpoint.is_none() {
error!("Provider has discovery disabled but no authorization endpoint set");
}

if provider.token_endpoint.is_none() {
error!("Provider has discovery disabled but no token endpoint set");
}

if provider.jwks_uri.is_none() {
error!("Provider has discovery disabled but no JWKS URI set");
}
}

let pkce_mode = match provider.pkce_method {
mas_config::UpstreamOAuth2PkceMethod::Auto => {
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
}
mas_config::UpstreamOAuth2PkceMethod::Always => {
mas_data_model::UpstreamOAuthProviderPkceMode::S256
}
mas_config::UpstreamOAuth2PkceMethod::Never => {
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
}
};

repo.upstream_oauth_provider()
.upsert(
let config: SyncConfig = root.load_config()?;
let clock = SystemClock::default();
let encrypter = config.secrets.encrypter();

// Grab a connection to the database
let mut conn = database_connection_from_config(&config.database).await?;

MIGRATOR
.run(&mut conn)
.instrument(info_span!("db.migrate"))
.await
.context("could not run migrations")?;

crate::sync::config_sync(
config.upstream_oauth2,
config.clients,
&mut conn,
&encrypter,
&clock,
provider.id,
UpstreamOAuthProviderParams {
issuer: provider.issuer,
human_name: provider.human_name,
brand_name: provider.brand_name,
scope: provider.scope.parse()?,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
additional_authorization_parameters: provider
.additional_authorization_parameters
.into_iter()
.collect(),
},
prune,
dry_run,
)
.await?;
}
}

{
let _span = info_span!("cli.config.sync.clients").entered();
let config_ids = config
.clients
.iter()
.map(|c| c.client_id)
.collect::<HashSet<_>>();

let existing = repo.oauth2_client().all_static().await?;
let existing_ids = existing.iter().map(|p| p.id).collect::<HashSet<_>>();
let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));
if prune {
for client in to_delete {
info!(client.id = %client.client_id, "Deleting client");

if dry_run {
continue;
}

repo.oauth2_client().delete(client).await?;
}
} else {
let len = to_delete.count();
match len {
0 => {},
1 => warn!("A static client in the database is not in the config. Run with `--prune` to delete it."),
n => warn!("{n} static clients in the database are not in the config. Run with `--prune` to delete them."),
}
}

for client in config.clients.iter() {
let _span = info_span!("client", client.id = %client.client_id).entered();
if existing_ids.contains(&client.client_id) {
info!("Updating client");
} else {
info!("Adding client");
}

if dry_run {
continue;
}

let client_secret = client.client_secret();
let client_auth_method = client.client_auth_method();
let jwks = client.jwks();
let jwks_uri = client.jwks_uri();

// TODO: should be moved somewhere else
let encrypted_client_secret = client_secret
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;

repo.oauth2_client()
.upsert_static(
client.client_id,
client_auth_method,
encrypted_client_secret,
jwks.cloned(),
jwks_uri.cloned(),
client.redirect_uris.clone(),
)
.await?;
}
}

// Get the lock and release it to commit the transaction
let lock = repo.into_inner();
let txn = lock.release_now().await?;
if dry_run {
info!("Dry run, rolling back changes");
txn.rollback().await?;
} else {
txn.commit().await?;
Ok(())
}
Ok(())
}
29 changes: 28 additions & 1 deletion crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, MetadataCa
use mas_listener::{server::Server, shutdown::ShutdownStream};
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use rand::{
distributions::{Alphanumeric, DistString},
Expand All @@ -39,6 +40,7 @@ use crate::{
},
};

#[allow(clippy::struct_excessive_bools)]
#[derive(Parser, Debug, Default)]
pub(super) struct Options {
/// Do not apply pending migrations on start
Expand All @@ -53,6 +55,10 @@ pub(super) struct Options {
/// Do not start the task worker
#[arg(long)]
no_worker: bool,

/// Do not sync the configuration with the database
#[arg(long)]
no_sync: bool,
}

impl Options {
Expand Down Expand Up @@ -88,14 +94,35 @@ impl Options {
.context("could not run migrations")?;
}

let encrypter = config.secrets.encrypter();

if self.no_sync {
info!("Skipping configuration sync");
} else {
// Sync the configuration with the database
let mut conn = pool.acquire().await?;
let clients_config = root.load_config()?;
let upstream_oauth2_config = root.load_config()?;

crate::sync::config_sync(
upstream_oauth2_config,
clients_config,
&mut conn,
&encrypter,
&SystemClock::default(),
false,
false,
)
.await?;
}

// Initialize the key store
let key_store = config
.secrets
.key_store()
.await
.context("could not import keys from config")?;

let encrypter = config.secrets.encrypter();
let cookie_manager =
CookieManager::derive_from(config.http.public_base.clone(), &config.secrets.encryption);

Expand Down
Loading

0 comments on commit c0a9d27

Please sign in to comment.