Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add external oidc tests #855

Merged
merged 5 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions e2e/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export const routes = {
users: '/admin/users',
openid: '/admin/openid',
overview: '/admin/overview',
settings: '/admin/settings',
},
authorize: '/api/v1/oauth/authorize',
};
Expand Down
99 changes: 99 additions & 0 deletions e2e/tests/externalopenid.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { expect, test } from '@playwright/test';

import { defaultUserAdmin, routes, testsConfig, testUserTemplate } from '../config';
import { NetworkForm, OpenIdClient, User } from '../types';
import { apiCreateUser } from '../utils/api/users';
import { loginBasic } from '../utils/controllers/login';
import { logout } from '../utils/controllers/logout';
import { copyOpenIdClientIdAndSecret } from '../utils/controllers/openid/copyClientId';
import { CreateExternalProvider } from '../utils/controllers/openid/createExternalProvider';
import { CreateOpenIdClient } from '../utils/controllers/openid/createOpenIdClient';
import { createNetwork } from '../utils/controllers/vpn/createNetwork';
import { dockerDown, dockerRestart } from '../utils/docker';
import { waitForBase } from '../utils/waitForBase';
import { waitForPromise } from '../utils/waitForPromise';
import { waitForRoute } from '../utils/waitForRoute';

test.describe('External OIDC.', () => {
const testUser: User = { ...testUserTemplate, username: 'test' };

const client: OpenIdClient = {
name: 'test 01',
redirectURL: [
'http://localhost:8000/auth/callback',
'http://localhost:8080/openid/callback',
],
scopes: ['openid', 'profile', 'email'],
};

const testNetwork: NetworkForm = {
name: 'test network',
address: '10.10.10.1/24',
endpoint: '127.0.0.1',
port: '5055',
};

test.beforeEach(async ({ browser }) => {
dockerRestart();
await CreateOpenIdClient(browser, client);
[client.clientID, client.clientSecret] = await copyOpenIdClientIdAndSecret(
browser,
client.name
);
const context = await browser.newContext();
const page = await context.newPage();
await CreateExternalProvider(browser, client);
await loginBasic(page, defaultUserAdmin);
await apiCreateUser(page, testUser);
await logout(page);
await createNetwork(browser, testNetwork);
context.close();
});

test.afterAll(() => {
dockerDown();
});

test('Login through external oidc.', async ({ page }) => {
expect(client.clientID).toBeDefined();
expect(client.clientSecret).toBeDefined();
await waitForBase(page);
const oidcLoginButton = await page.getByTestId('login-oidc');
expect(oidcLoginButton).not.toBeNull();
expect(await oidcLoginButton.textContent()).toBe(`Sign in with ${client.name}`);
await oidcLoginButton.click();
await page.getByTestId('login-form-username').fill(testUser.username);
await page.getByTestId('login-form-password').fill(testUser.password);
await page.getByTestId('login-form-submit').click();
await page.getByTestId('openid-allow').click();
await waitForRoute(page, routes.me);
const authorizedApps = await page
.getByTestId('authorized-apps')
.locator('div')
.textContent();
expect(authorizedApps).toContain(client.name);
});

test('Complete enrollment through external OIDC', async ({ page }) => {
await waitForBase(page);
await page.goto(testsConfig.ENROLLMENT_URL);
await waitForPromise(2000);
await page.getByTestId('select-enrollment').click();
await page.getByTestId('login-oidc').click();
await page.getByTestId('login-form-username').fill(testUser.username);
await page.getByTestId('login-form-password').fill(testUser.password);
await page.getByTestId('login-form-submit').click();
await page.getByTestId('openid-allow').click();
const instanceUrlBox = page
.locator('div')
.filter({ hasText: /^Instance URL$/ })
.getByRole('textbox');

expect(await instanceUrlBox.inputValue()).toBe('http://localhost:8080/');
const instanceTokenBox = page
.locator('div')
.filter({ hasText: /^Token$/ })
.getByRole('textbox');
expect((await instanceTokenBox.inputValue()).length).toBeGreaterThan(1);
});
});
2 changes: 1 addition & 1 deletion e2e/tests/openid.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ test.describe('Authorize OpenID client.', () => {

const client: OpenIdClient = {
name: 'test 01',
redirectURL: 'https://oidcdebugger.com/debug',
redirectURL: ['https://oidcdebugger.com/debug'],
scopes: ['openid'],
};

Expand Down
3 changes: 2 additions & 1 deletion e2e/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ export type User = {
export type OpenIdClient = {
name: string;
clientID?: string;
redirectURL: string;
clientSecret?: string;
redirectURL: string[];
scopes: OpenIdScope[];
};

Expand Down
23 changes: 23 additions & 0 deletions e2e/utils/controllers/openid/copyClientId.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,26 @@ export const copyOpenIdClientId = async (browser: Browser, clientId: number) =>
const id = await getPageClipboard(page);
return id;
};

export const copyOpenIdClientIdAndSecret = async (
browser: Browser,
clientName: string
) => {
const context = await browser.newContext();
const page = await context.newPage();
await waitForBase(page);
await loginBasic(page, defaultUserAdmin);
await page.goto(routes.base + routes.admin.openid, { waitUntil: 'networkidle' });
await page
.locator('div')
.filter({
hasText: new RegExp(`^${clientName}$`),
})
.click();
await page.getByTestId('copy-client-id').click();
const id = await getPageClipboard(page);
await page.locator('.variant-copy').nth(1).click();
const secret = await getPageClipboard(page);
await context.close();
return [id, secret];
};
22 changes: 22 additions & 0 deletions e2e/utils/controllers/openid/createExternalProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { Browser } from 'playwright';

import { defaultUserAdmin, routes } from '../../../config';
import { OpenIdClient } from '../../../types';
import { waitForBase } from '../../waitForBase';
import { loginBasic } from '../login';

export const CreateExternalProvider = async (browser: Browser, client: OpenIdClient) => {
const context = await browser.newContext();
const page = await context.newPage();
await waitForBase(page);
await loginBasic(page, defaultUserAdmin);
await page.goto(routes.base + routes.admin.settings, { waitUntil: 'networkidle' });
await page.getByRole('button', { name: 'OpenID' }).click();
await page.locator('.content-frame').click();
await page.getByRole('button', { name: 'Custom' }).click();
await page.getByTestId('field-base_url').fill('http://localhost:8000/');
await page.getByTestId('field-client_id').fill(client.clientID || '');
await page.getByTestId('field-client_secret').fill(client.clientSecret || '');
await page.getByTestId('field-display_name').fill(client.name);
await page.getByRole('button', { name: 'Save changes' }).click();
};
11 changes: 10 additions & 1 deletion e2e/utils/controllers/openid/createOpenIdClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@ export const CreateOpenIdClient = async (browser: Browser, client: OpenIdClient)
await modalElement.waitFor({ state: 'visible' });
const modalForm = modalElement.locator('form');
await modalForm.getByTestId('field-name').type(client.name);
await modalForm.getByTestId('field-redirect_uri.0.url').type(client.redirectURL);
const urls = client.redirectURL.length;
for (let i = 0; i < urls; i++) {
const isLast = i === urls - 1;
await modalForm
.getByTestId(`field-redirect_uri.${i}.url`)
.fill(client.redirectURL[i]);
if (!isLast) {
await modalForm.locator('button:has-text("Add URL")').click();
}
}
for (const scope of client.scopes) {
await modalForm.getByTestId(`field-scope-${scope}`).click();
}
Expand Down
7 changes: 2 additions & 5 deletions tests/common/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ pub struct TestClient {
#[allow(dead_code)]
impl TestClient {
#[must_use]
pub async fn new(app: Router) -> Self {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Could not bind ephemeral socket");
pub async fn new(app: Router, listener: TcpListener) -> Self {
let port = listener.local_addr().unwrap().port();

tokio::spawn(async move {
Expand Down Expand Up @@ -58,7 +55,7 @@ impl TestClient {
///
/// this is useful when trying to check if Location headers in responses
/// are generated correctly as Location contains an absolute URL
fn base_url(&self) -> String {
pub fn base_url(&self) -> String {
let mut s = String::from("http://localhost:");
s.push_str(&self.port.to_string());
s
Expand Down
62 changes: 51 additions & 11 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
pub(crate) mod client;

use std::sync::{Arc, Mutex};
use std::{
str::FromStr,
sync::{Arc, Mutex},
};

use defguard::{
auth::failed_login::FailedLoginMap,
Expand All @@ -13,10 +16,11 @@ use defguard::{
mail::Mail,
SERVER_CONFIG,
};
use reqwest::{header::HeaderName, StatusCode};
use reqwest::{header::HeaderName, StatusCode, Url};
use secrecy::ExposeSecret;
use serde_json::json;
use sqlx::{postgres::PgConnectOptions, query, types::Uuid, PgPool};
use tokio::net::TcpListener;
use tokio::sync::{
broadcast::{self, Receiver},
mpsc::{unbounded_channel, UnboundedReceiver},
Expand All @@ -31,9 +35,17 @@ pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for
#[allow(dead_code, clippy::declare_interior_mutable_const)]
pub const X_FORWARDED_URI: HeaderName = HeaderName::from_static("x-forwarded-uri");

pub async fn init_test_db() -> (PgPool, DefGuardConfig) {
let config = DefGuardConfig::new_test_config();
/// Allows overriding the default DefGuard URL for tests, as during the tests, the server has a random port, making the URL unpredictable beforehand.
// TODO: Allow customizing the whole config, not just the URL
pub fn init_config(custom_defguard_url: Option<&str>) -> DefGuardConfig {
let url = custom_defguard_url.unwrap_or("http://localhost:8000");
let mut config = DefGuardConfig::new_test_config();
config.url = Url::from_str(url).unwrap();
let _ = SERVER_CONFIG.set(config.clone());
config
}

pub async fn init_test_db(config: &DefGuardConfig) -> PgPool {
let opts = PgConnectOptions::new()
.host(&config.database_host)
.port(config.database_port)
Expand All @@ -57,9 +69,9 @@ pub async fn init_test_db() -> (PgPool, DefGuardConfig) {
)
.await;

initialize_users(&pool, &config).await;
initialize_users(&pool, config).await;

(pool, config)
pool
}

async fn initialize_users(pool: &PgPool, config: &DefGuardConfig) {
Expand Down Expand Up @@ -112,7 +124,11 @@ impl ClientState {
}
}

pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClient, ClientState) {
pub async fn make_base_client(
pool: PgPool,
config: DefGuardConfig,
listener: TcpListener,
) -> (TestClient, ClientState) {
let (tx, rx) = unbounded_channel::<AppEvent>();
let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone())));
let (wg_tx, wg_rx) = broadcast::channel::<GatewayEvent>(16);
Expand All @@ -124,7 +140,7 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie

let license = License::new(
"test_customer".to_string(),
true,
false,
// Some(Utc.with_ymd_and_hms(2030, 1, 1, 0, 0, 0).unwrap()),
// Permanent license
None,
Expand Down Expand Up @@ -167,13 +183,35 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie
failed_logins,
);

(TestClient::new(webapp).await, client_state)
(TestClient::new(webapp, listener).await, client_state)
}

/// Make an instance url based on the listener
fn get_test_url(listener: &TcpListener) -> String {
let port = listener.local_addr().unwrap().port();
format!("http://localhost:{}", port)
}

#[allow(dead_code)]
pub async fn make_test_client() -> (TestClient, ClientState) {
let (pool, config) = init_test_db().await;
make_base_client(pool, config).await
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Could not bind ephemeral socket");
let config = init_config(None);
let pool = init_test_db(&config).await;
make_base_client(pool, config, listener).await
}

/// Makes a test client with a DEFGUARD_URL set to the random url of the listener.
/// This is useful when the instance's url real url needs to match the one set in the ENV variable.
#[allow(dead_code)]
pub async fn make_test_client_with_real_url() -> (TestClient, ClientState) {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Could not bind ephemeral socket");
let config = init_config(Some(&get_test_url(&listener)));
let pool = init_test_db(&config).await;
make_base_client(pool, config, listener).await
}

#[allow(dead_code)]
Expand All @@ -183,6 +221,8 @@ pub async fn fetch_user_details(client: &TestClient, username: &str) -> UserDeta
response.json().await
}

/// Exceeds enterprise free version limits by creating more than 1 network
#[allow(dead_code)]
pub async fn exceed_enterprise_limits(client: &TestClient) {
let auth = Auth::new("admin", "pass123");
client.post("/api/v1/auth").json(&auth).send().await;
Expand Down
13 changes: 10 additions & 3 deletions tests/openid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::str::FromStr;

use axum::http::header::ToStrError;
use claims::assert_err;
use common::init_config;
use defguard::{
config::DefGuardConfig,
db::{
Expand All @@ -26,6 +27,7 @@ use reqwest::{
use rsa::RsaPrivateKey;
use serde::Deserialize;
use sqlx::PgPool;
use tokio::net::TcpListener;

mod common;
use self::common::{client::TestClient, init_test_db, make_base_client, make_test_client};
Expand All @@ -36,7 +38,10 @@ async fn make_client() -> TestClient {
}

async fn make_client_v2(pool: PgPool, config: DefGuardConfig) -> TestClient {
let (client, _) = make_base_client(pool, config).await;
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Could not bind ephemeral socket");
let (client, _) = make_base_client(pool, config, listener).await;
client
}

Expand Down Expand Up @@ -402,7 +407,8 @@ static FAKE_REDIRECT_URI: &str = "http://test.server.tnt:12345/";

#[tokio::test]
async fn test_openid_authorization_code() {
let (pool, config) = init_test_db().await;
let config = init_config(None);
let pool = init_test_db(&config).await;

let issuer_url = IssuerUrl::from_url(config.url.clone());
let client = make_client_v2(pool.clone(), config.clone()).await;
Expand Down Expand Up @@ -505,7 +511,8 @@ async fn test_openid_authorization_code() {

#[tokio::test]
async fn test_openid_authorization_code_with_pkce() {
let (pool, mut config) = init_test_db().await;
let mut config = init_config(None);
let pool = init_test_db(&config).await;
let mut rng = rand::thread_rng();
config.openid_signing_key = RsaPrivateKey::new(&mut rng, 2048).ok();

Expand Down
Loading