diff --git a/Cargo.lock b/Cargo.lock index a95fc4e30..c13f10a57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -712,6 +712,18 @@ dependencies = [ "syn 2.0.49", ] +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.49", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -1964,6 +1976,7 @@ dependencies = [ "clap", "console-subscriber", "criterion", + "enum_dispatch", "figment", "futures-util", "http-body-util", diff --git a/Cargo.toml b/Cargo.toml index dd97b03b8..6a7e4cfc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ ott-balancer = { path = "crates/ott-balancer" } ott-balancer-protocol = { path = "crates/ott-balancer-protocol" } pin-project = "1.0.12" prometheus = { version = "0.13.3", features = ["process"] } +enum_dispatch = "0.3.13" rand = "0.8.5" reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "2c11ef0", features = [ "json", diff --git a/crates/ott-balancer/Cargo.toml b/crates/ott-balancer/Cargo.toml index aab839244..92a7ae4d5 100644 --- a/crates/ott-balancer/Cargo.toml +++ b/crates/ott-balancer/Cargo.toml @@ -35,6 +35,7 @@ route-recognizer = "0.3.1" once_cell.workspace = true pin-project.workspace = true prometheus.workspace = true +enum_dispatch.workspace = true [dev-dependencies] criterion.workspace = true diff --git a/crates/ott-balancer/src/balancer.rs b/crates/ott-balancer/src/balancer.rs index fa82737fc..baa82fbec 100644 --- a/crates/ott-balancer/src/balancer.rs +++ b/crates/ott-balancer/src/balancer.rs @@ -17,11 +17,11 @@ use tracing::{debug, error, info, instrument, trace, warn}; use crate::balancer::collector::ClientState; use crate::client::ClientLink; -use crate::config::BalancerConfig; +use crate::config::{BalancerConfig, MonolithSelectionStrategy}; use crate::connection::BALANCER_ID; use crate::monolith::Room; use crate::room::RoomLocator; -use crate::selection::{MinRoomsSelector, MonolithSelection}; +use crate::selection::MonolithSelection; use crate::{ client::{BalancerClient, NewClient}, messages::*, @@ -186,27 +186,14 @@ impl BalancerLink { } } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct BalancerContext { pub clients: HashMap, pub monoliths: HashMap, pub rooms_to_monoliths: HashMap, pub monoliths_by_region: HashMap>, - pub monolith_selection: Box, + pub monolith_selection: MonolithSelectionStrategy, } - -impl Default for BalancerContext { - fn default() -> Self { - BalancerContext { - clients: HashMap::default(), - monoliths: HashMap::default(), - rooms_to_monoliths: HashMap::default(), - monoliths_by_region: HashMap::default(), - monolith_selection: Box::::default(), - } - } -} - impl BalancerContext { pub fn new() -> Self { Default::default() diff --git a/crates/ott-balancer/src/config.rs b/crates/ott-balancer/src/config.rs index 5ee1a225e..b840e573c 100644 --- a/crates/ott-balancer/src/config.rs +++ b/crates/ott-balancer/src/config.rs @@ -1,15 +1,29 @@ use std::{borrow::BorrowMut, path::PathBuf, sync::Once}; use clap::{Parser, ValueEnum}; +use enum_dispatch::enum_dispatch; use figment::providers::Format; use serde::Deserialize; use ott_common::discovery::DiscoveryConfig; +use crate::selection::MinRoomsSelector; + static mut CONFIG: Option = None; static CONFIG_INIT: Once = Once::new(); +#[derive(Debug, Deserialize, Copy, Clone)] +#[enum_dispatch] +pub enum MonolithSelectionStrategy { + MinRooms(MinRoomsSelector), +} + +impl Default for MonolithSelectionStrategy { + fn default() -> Self { + MonolithSelectionStrategy::MinRooms(MinRoomsSelector) + } +} #[derive(Debug, Deserialize)] #[serde(default)] pub struct BalancerConfig { @@ -19,6 +33,7 @@ pub struct BalancerConfig { pub region: String, /// The API key that clients can use to access restricted endpoints. pub api_key: Option, + pub selection_strategy: Option, } impl Default for BalancerConfig { @@ -28,6 +43,7 @@ impl Default for BalancerConfig { discovery: DiscoveryConfig::default(), region: "unknown".to_owned(), api_key: None, + selection_strategy: None, } } } diff --git a/crates/ott-balancer/src/lib.rs b/crates/ott-balancer/src/lib.rs index f00b320e8..bc56b2bd5 100644 --- a/crates/ott-balancer/src/lib.rs +++ b/crates/ott-balancer/src/lib.rs @@ -96,7 +96,20 @@ pub async fn run() -> anyhow::Result<()> { let (discovery_tx, discovery_rx) = tokio::sync::mpsc::channel(2); info!("Starting balancer"); - let ctx = Arc::new(RwLock::new(BalancerContext::new())); + + let ctx = Arc::new(RwLock::new( + if let Some(selection_strategy) = config.selection_strategy { + info!("Using selection strategy: {:?}", selection_strategy); + BalancerContext { + monolith_selection: selection_strategy, + ..BalancerContext::new() + } + } else { + info!("Using default selection strategy"); + BalancerContext::new() + }, + )); + let balancer = Balancer::new(ctx.clone()); let service_link = balancer.new_link(); let conman_link = balancer.new_link(); diff --git a/crates/ott-balancer/src/selection.rs b/crates/ott-balancer/src/selection.rs index bf8f949f9..27af3a71b 100644 --- a/crates/ott-balancer/src/selection.rs +++ b/crates/ott-balancer/src/selection.rs @@ -1,9 +1,12 @@ -use rand::seq::IteratorRandom; - +use crate::config::MonolithSelectionStrategy; use crate::monolith::BalancerMonolith; +use enum_dispatch::enum_dispatch; +use rand::seq::IteratorRandom; +use serde::Deserialize; -#[derive(Debug, Default)] +#[derive(Debug, Default, Deserialize, Copy, Clone)] pub struct MinRoomsSelector; +#[enum_dispatch(MonolithSelectionStrategy)] pub trait MonolithSelection: std::fmt::Debug { fn select_monolith<'a>( &'a self,