diff --git a/docker-compose.yml b/docker-compose.yml index 991df2d01..be68a9c4e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,3 +8,27 @@ services: - ./docker/sql_setup.sh:/docker-entrypoint-initdb.d/sql_setup.sh environment: POSTGRES_PASSWORD: postgres + postgres_replica: + image: docker.io/postgres:17 + user: postgres + ports: + - 5434:5433 + volumes: + - ./docker/sql_setup_replica.sh:/docker-entrypoint-initdb.d/sql_setup.sh + environment: + PGUSER: replicator + PGPASSWORD: replicator_password + command: | + bash -c " + until pg_basebackup --pgdata=/var/lib/postgresql/data -R --slot=replication_slot -C --host=postgres --port=5433 + do + echo 'Waiting for primary to connect...' + sleep 1s + done + echo 'Backup done, starting replica...' + chmod 0700 /var/lib/postgresql/data + /docker-entrypoint-initdb.d/sql_setup_replica.sh + postgres + " + depends_on: + - postgres diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 0315ac805..26acf4edf 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -80,6 +80,8 @@ hostssl all ssl_user ::0/0 trust host all ssl_user 0.0.0.0/0 reject host all ssl_user ::0/0 reject +host replication replicator 0.0.0.0/0 trust + # IPv4 local connections: host all postgres 0.0.0.0/0 trust # IPv6 local connections: @@ -91,6 +93,7 @@ EOCONF psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL CREATE ROLE pass_user PASSWORD 'password' LOGIN; CREATE ROLE md5_user PASSWORD 'password' LOGIN; + CREATE ROLE replicator WITH REPLICATION PASSWORD 'password' LOGIN; SET password_encryption TO 'scram-sha-256'; CREATE ROLE scram_user PASSWORD 'password' LOGIN; CREATE ROLE ssl_user LOGIN; diff --git a/docker/sql_setup_replica.sh b/docker/sql_setup_replica.sh new file mode 100644 index 000000000..9ac175332 --- /dev/null +++ b/docker/sql_setup_replica.sh @@ -0,0 +1,91 @@ +#!/bin/bash +set -e + +cat > "$PGDATA/server.key" <<-EOKEY +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAllItXwrj62MkxKVlz2FimJk42WWc3K82Rn2vAl6z38zQxSCj +t9uWwXWTx5YOdGiUcA+JUAruZxqN7vdfphJoYtTrcrpT4rC/FsCMImBxkj1cxdYT +q94SFn9bQBRZk7RUx4Kolt+/h0d3PpNIb4DbyQ8A0MVvNVxLpRRVwc6yQP+NkRMy +gHR+m3P8fxHEtkHCVy7HORbASvN8fRlREMHDL2hkadX0BNM72DDo+DWhPA8GF6WX +tIl1gU6GP6pSbEeMHD3f+uj7f9iSjvkrHrOt2nLUQ9Qnev2nhmU0/dOIweQ17/Fr +lL9jYDUUFNORyjRnlXXUoP5BO/LdEAAqT2A0pwIDAQABAoIBAQCIXu74XUneHuiZ +Wa+eTqwC4mZXmz6OWonzs0vU65NlgksXuv+r6ZO/2GoD1Bcy9jlL3Fxm+DPF56pB +07u7TtHSb3VWdMFrU4tYGcBH45TE5dRHSmo4LlPcgxeGb6/ANwX+pYNKtJvuHyCH +7Vf2iEFcCrdjrumv0BZ0IZmXJGxEV+7mK2Og0bZ/zbmJNaH25muuWj6BKlvLhL0N +S2LlBjKx3HqtppUgUqNFqjLs6IA1u79S5dAomOsxZtnuByaX5WFzpktU2pveZmyF +cl0dwHYZIaxR3ewYeQXGF8ANUmIx3nnxD2JOysPkitaGzeqt6dQZV14tPlDZDKat +Vf0b6BHhAoGBAMWV7rG+7nVXoQ30CIcPGklkST3mVOlrzeBbKP1SeAwoGRbfsdhp +rFMkh5UxTexnOzD4O8HPuJ6NGeWRQfqZT1nnjwHPeJWtiMHT6cnWxlzvxAZ61mio +0jRfb8flhgFKk+G9+Xa6WaYAAwGWdF062EMe2Ym92oKM9ilTPGFVRk1XAoGBAMLD +ETSQd2UqTF/y7wxMPqF3l6d1KBjwpuNuin2IjkXTOfGkDnAU3mSQlr7K1IPX8NPO +gdyMfJoysfRaBuRcNA/o/0l0wyxW4HWtTtPYI0+pRCFtRLsI1MB997QKeaGKb+me +3nBXkOksPSr9oa0Cs27z2cSoBOkpq2N/zzBseHExAoGAOyq3rKBZNehEwTHnb9I0 +8+9FA3U6zh9LKjkCIEGW00Uapj/cOMsEIG2a8DEwfW84SWS8OEBkr43fSGBkGo/Y +NDrkFw2ytVee0TQNGTTod6IQ2EPmera7I5XEml5/71kOyZWi40vQVqZAQDR2qgha +BFdzmwywJ1Hg0OUs+pSXlccCgYEAgyOVki80NYolovWQwFcWVOKR2s+oECL6PGlS +FvS714hCm9I7ZnymwlAZMJ6iOaRNJFEIX9i4jZtU95Mm0NzEsXHRc0SLpm9Y8+Oe +EEaYgCsZFOjePpHTr0kiYLgs7fipIkU2wa40hMyk4y2kjzoiV7MaDrCTnevQ205T +0+c1sgECgYBAXKcwdkh9JVSrLXFamsxiOx3MZ0n6J1d28wpdA3y4Y4AAJm4TGgFt +eG/6qHRy6CHdFtJ7a84EMe1jaVLQJYW/VrOC2bWLftkU7qaOnkXHvr4CAHsXQHcx +JhLfvh4ab3KyoK/iimifvcoS5z9gp7IBFKMyh5IeJ9Y75TgcfJ5HMg== +-----END RSA PRIVATE KEY----- +EOKEY +chmod 0600 "$PGDATA/server.key" + +cat > "$PGDATA/server.crt" <<-EOCERT +-----BEGIN CERTIFICATE----- +MIID9DCCAtygAwIBAgIJAIYfg4EQ2pVAMA0GCSqGSIb3DQEBBQUAMFkxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xNjA2MjgyMjQw +NDFaFw0yNjA2MjYyMjQwNDFaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21l +LVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNV +BAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAJZS +LV8K4+tjJMSlZc9hYpiZONllnNyvNkZ9rwJes9/M0MUgo7fblsF1k8eWDnRolHAP +iVAK7mcaje73X6YSaGLU63K6U+KwvxbAjCJgcZI9XMXWE6veEhZ/W0AUWZO0VMeC +qJbfv4dHdz6TSG+A28kPANDFbzVcS6UUVcHOskD/jZETMoB0fptz/H8RxLZBwlcu +xzkWwErzfH0ZURDBwy9oZGnV9ATTO9gw6Pg1oTwPBhell7SJdYFOhj+qUmxHjBw9 +3/ro+3/Yko75Kx6zrdpy1EPUJ3r9p4ZlNP3TiMHkNe/xa5S/Y2A1FBTTkco0Z5V1 +1KD+QTvy3RAAKk9gNKcCAwEAAaOBvjCBuzAdBgNVHQ4EFgQUEcuoFxzUZ4VV9VPv +5frDyIuFA5cwgYsGA1UdIwSBgzCBgIAUEcuoFxzUZ4VV9VPv5frDyIuFA5ehXaRb +MFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJ +bnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMTCWxvY2FsaG9zdIIJAIYf +g4EQ2pVAMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAHwMzmXdtz3R +83HIdRQic40bJQf9ucSwY5ArkttPhC8ewQGyiGexm1Tvx9YA/qT2rscKPHXCPYcP +IUE+nJTc8lQb8wPnFwGdHUsJfCvurxE4Yv4Oi74+q1enhHBGsvhFdFY5jTYD9unM +zBEn+ZHX3PlKhe3wMub4khBTbPLK+n/laQWuZNsa+kj7BynkAg8W/6RK0Z0cJzzw +aiVP0bSvatAAcSwkEfKEv5xExjWqoewjSlQLEZYIjJhXdtx/8AMnrcyxrFvKALUQ +9M15FXvlPOB7ez14xIXQBKvvLwXvteHF6kYbzg/Bl1Q2GE9usclPa4UvTpnLv6gq +NmFaAhoxnXA= +-----END CERTIFICATE----- +EOCERT + +cat >> "$PGDATA/postgresql.conf" <<-EOCONF +port = 5433 +ssl = on +ssl_cert_file = 'server.crt' +ssl_key_file = 'server.key' +EOCONF + +cat > "$PGDATA/pg_hba.conf" <<-EOCONF +# TYPE DATABASE USER ADDRESS METHOD +host all pass_user 0.0.0.0/0 password +host all md5_user 0.0.0.0/0 md5 +host all scram_user 0.0.0.0/0 scram-sha-256 +host all pass_user ::0/0 password +host all md5_user ::0/0 md5 +host all scram_user ::0/0 scram-sha-256 + +hostssl all ssl_user 0.0.0.0/0 trust +hostssl all ssl_user ::0/0 trust +host all ssl_user 0.0.0.0/0 reject +host all ssl_user ::0/0 reject + +host replication replicator 0.0.0.0/0 trust + +# IPv4 local connections: +host all postgres 0.0.0.0/0 trust +# IPv6 local connections: +host all postgres ::0/0 trust +# Unix socket connections: +local all postgres trust +EOCONF diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 7ba5638e3..39a27f67a 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -36,6 +36,12 @@ pub enum TargetSessionAttrs { ReadWrite, /// The session allow only reads. ReadOnly, + /// The session allow primary node. + Primary, + /// The session allow standby node. + Standby, + /// The session prefers the standby node. + PreferStandby, } /// TLS configuration. @@ -677,6 +683,9 @@ impl Config { "any" => TargetSessionAttrs::Any, "read-write" => TargetSessionAttrs::ReadWrite, "read-only" => TargetSessionAttrs::ReadOnly, + "primary" => TargetSessionAttrs::Primary, + "standby" => TargetSessionAttrs::Standby, + "prefer-standby" => TargetSessionAttrs::PreferStandby, _ => { return Err(Error::config_parse(Box::new(InvalidValue( "target_session_attrs", diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index e97a7a2a3..1e0357462 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -105,8 +105,8 @@ where } let mut last_err = None; - for addr in addrs { - match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) + for addr in &addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config, config.target_session_attrs) .await { Ok(stream) => return Ok(stream), @@ -117,6 +117,21 @@ where }; } + // If initial pass wtih prefer standby failed then consider write hosts + if config.target_session_attrs == TargetSessionAttrs::PreferStandby { + for addr in &addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config, TargetSessionAttrs::Any) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => { + last_err = Some(e); + continue; + } + }; + } + } + Err(last_err.unwrap_or_else(|| { Error::connect(io::Error::new( io::ErrorKind::InvalidInput, @@ -126,7 +141,7 @@ where } #[cfg(unix)] Host::Unix(path) => { - connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await + connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config, config.target_session_attrs).await } } } @@ -137,6 +152,7 @@ async fn connect_once( port: u16, tls: &mut T, config: &Config, + target_session_attrs: TargetSessionAttrs, ) -> Result<(Client, Connection), Error> where T: MakeTlsConnect, @@ -160,7 +176,7 @@ where let has_hostname = hostname.is_some(); let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; - if config.target_session_attrs != TargetSessionAttrs::Any { + if target_session_attrs == TargetSessionAttrs::ReadOnly || target_session_attrs == TargetSessionAttrs::ReadWrite { let rows = client.simple_query_raw("SHOW transaction_read_only"); pin_mut!(rows); @@ -187,14 +203,14 @@ where Some(SimpleQueryMessage::Row(row)) => { let read_only_result = row.try_get(0)?; if read_only_result == Some("on") - && config.target_session_attrs == TargetSessionAttrs::ReadWrite + && target_session_attrs == TargetSessionAttrs::ReadWrite { return Err(Error::connect(io::Error::new( io::ErrorKind::PermissionDenied, "database does not allow writes", ))); } else if read_only_result == Some("off") - && config.target_session_attrs == TargetSessionAttrs::ReadOnly + && target_session_attrs == TargetSessionAttrs::ReadOnly { return Err(Error::connect(io::Error::new( io::ErrorKind::PermissionDenied, @@ -208,6 +224,55 @@ where None => return Err(Error::unexpected_message()), } } + } else if target_session_attrs == TargetSessionAttrs::Primary || target_session_attrs == TargetSessionAttrs::Standby || target_session_attrs == TargetSessionAttrs::PreferStandby { + let rows = client.simple_query_raw("SELECT pg_catalog.pg_is_in_recovery()"); + pin_mut!(rows); + + let rows = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Err(Error::closed())); + } + + rows.as_mut().poll(cx) + }) + .await?; + pin_mut!(rows); + + loop { + let next = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Some(Err(Error::closed()))); + } + + rows.as_mut().poll_next(cx) + }); + + match next.await.transpose()? { + Some(SimpleQueryMessage::Row(row)) => { + let primary_result = row.try_get(0)?; + println!("{:?}", primary_result); + if primary_result == Some("t") + && target_session_attrs == TargetSessionAttrs::Primary + { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database is not primary", + ))); + } else if primary_result == Some("f") + && target_session_attrs == TargetSessionAttrs::Standby + { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database is not standby", + ))); + } else { + break; + } + } + Some(_) => {} + None => return Err(Error::unexpected_message()), + } + } } client.set_socket_config(SocketConfig { diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs index 35eeca72b..f041e4445 100644 --- a/tokio-postgres/tests/test/parse.rs +++ b/tokio-postgres/tests/test/parse.rs @@ -42,6 +42,30 @@ fn settings() { .keepalives_idle(Duration::from_secs(30)) .target_session_attrs(TargetSessionAttrs::ReadOnly), ); + check( + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=primary", + Config::new() + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::Primary), + ); + check( + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=standby", + Config::new() + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::Standby), + ); + check( + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=prefer-standby", + Config::new() + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::PreferStandby), + ); check( "sslnegotiation=direct", Config::new().ssl_negotiation(SslNegotiation::Direct), diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 86c1f0701..d188f38a4 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -49,9 +49,51 @@ async fn wrong_port_count() { .unwrap(); } +#[tokio::test] +async fn target_session_attrs_primary_ok() { + smoke_test("host=localhost,localhost port=5434,5433 user=postgres target_session_attrs=primary").await; +} + +#[tokio::test] +async fn target_session_attrs_standby_ok() { + smoke_test("host=localhost,localhost port=5433,5434 user=postgres target_session_attrs=standby").await; +} + +#[tokio::test] +async fn target_session_attrs_prefer_standby_ok() { + smoke_test("host=localhost,localhost port=5433,5434 user=postgres target_session_attrs=prefer-standby").await; +} + #[tokio::test] async fn target_session_attrs_ok() { - smoke_test("host=localhost port=5433 user=postgres target_session_attrs=read-write").await; + smoke_test("host=localhost,localhost port=5434,5433 user=postgres target_session_attrs=read-write").await; +} + +#[tokio::test] +async fn target_session_attrs_read_only_ok() { + smoke_test("host=localhost,localhost port=5433,5434 user=postgres target_session_attrs=read-only").await; +} + +#[tokio::test] +async fn target_session_attrs_prefer_standby_err() { + tokio_postgres::connect( + "host=localhost port=5433 user=postgres target_session_attrs=prefer-standby + options='-c default_transaction_read_only=on'", + NoTls, + ) + .await + .err(); +} + +#[tokio::test] +async fn target_session_attrs_primary_err() { + tokio_postgres::connect( + "host=localhost port=5433 user=postgres target_session_attrs=primary + options='-c default_transaction_read_only=on'", + NoTls, + ) + .await + .err(); } #[tokio::test]