Skip to content

Commit

Permalink
refactor: update KeycloakTokenValidator to support multiple realms an…
Browse files Browse the repository at this point in the history
…d Connector Builder Server (#13497)
  • Loading branch information
pmossman committed Aug 13, 2024
1 parent 00bab7d commit d3eb6f9
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.AuthProvider;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -40,6 +43,13 @@ public static String getJwtPayloadToken(final String jwtToken) {
return jwtPayload;
}

public static Map<String, Object> tokenToAttributes(final String jwtToken) {
final String rawJwtPayload = getJwtPayloadToken(jwtToken);
final String jwtPayloadDecoded = new String(Base64.getUrlDecoder().decode(rawJwtPayload), StandardCharsets.UTF_8);
final JsonNode jwtPayloadNode = Jsons.deserialize(jwtPayloadDecoded);
return convertJwtPayloadToUserAttributes(jwtPayloadNode);
}

/**
* Going through JWT payload part and extract fields backend cares about into a map.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class AirbyteKeycloakConfiguration {
var password: String = ""
var resetRealm: Boolean = false

fun getKeycloakUserInfoEndpoint(): String {
fun getKeycloakUserInfoEndpointForRealm(realm: String): String {
val hostWithoutTrailingSlash = if (host.endsWith("/")) host.substring(0, host.length - 1) else host
val basePathWithLeadingSlash = if (basePath.startsWith("/")) basePath else "/$basePath"
val keycloakUserInfoURI = "/protocol/openid-connect/userinfo"
return "$protocol://$hostWithoutTrailingSlash$basePathWithLeadingSlash/realms/$airbyteRealm$keycloakUserInfoURI"
return "$protocol://$hostWithoutTrailingSlash$basePathWithLeadingSlash/realms/$realm$keycloakUserInfoURI"
}

fun getServerUrl(): String = "$protocol://$host$basePath"
Expand Down
2 changes: 2 additions & 0 deletions airbyte-commons-server/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies {
implementation(libs.bundles.log4j)
implementation(libs.commons.io)
implementation(libs.kotlin.logging)
implementation(libs.reactor.core)
implementation(project(":oss:airbyte-analytics"))
implementation(project(":oss:airbyte-api:connector-builder-api"))
implementation(project(":oss:airbyte-api:problems-api"))
Expand Down Expand Up @@ -81,6 +82,7 @@ dependencies {
testImplementation(libs.bundles.micronaut.test)
testImplementation(libs.micronaut.http)
testImplementation(libs.mockk)
testImplementation(libs.reactor.test)

testRuntimeOnly(libs.junit.jupiter.engine)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
* Copyright (c) 2020-2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.server.pro;
package io.airbyte.commons.server.authorization;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.airbyte.commons.auth.AuthRole;
import io.airbyte.commons.auth.RequiresAuthMode;
import io.airbyte.commons.auth.config.AirbyteKeycloakConfiguration;
import io.airbyte.commons.auth.config.AuthMode;
import io.airbyte.commons.auth.support.JwtTokenParser;
import io.airbyte.commons.json.Jsons;
import io.airbyte.commons.license.annotation.RequiresAirbyteProEnabled;
import io.airbyte.commons.server.support.RbacRoleHelper;
import io.micrometer.common.util.StringUtils;
import io.micronaut.http.HttpRequest;
import io.micronaut.security.authentication.Authentication;
Expand All @@ -22,8 +21,7 @@
import jakarta.inject.Singleton;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.Request;
Expand All @@ -32,26 +30,26 @@
import reactor.core.publisher.Mono;

/**
* Token Validator for Airbyte Pro. Performs an online validation of the token against the Keycloak
* server.
* Token Validator for Airbyte Cloud and Enterprise. Performs an online validation of the token
* against the Keycloak server.
*/
@Slf4j
@Singleton
@RequiresAirbyteProEnabled
@RequiresAuthMode(AuthMode.OIDC)
@SuppressWarnings({"PMD.PreserveStackTrace", "PMD.UseTryWithResources", "PMD.UnusedFormalParameter", "PMD.UnusedPrivateMethod",
"PMD.ExceptionAsFlowControl"})
public class KeycloakTokenValidator implements TokenValidator<HttpRequest<?>> {

private final OkHttpClient client;
private final AirbyteKeycloakConfiguration keycloakConfiguration;
private final RbacRoleHelper rbacRoleHelper;
private final TokenRoleResolver tokenRoleResolver;

public KeycloakTokenValidator(@Named("keycloakTokenValidatorHttpClient") final OkHttpClient okHttpClient,
final AirbyteKeycloakConfiguration keycloakConfiguration,
final RbacRoleHelper rbacRoleHelper) {
final TokenRoleResolver tokenRoleResolver) {
this.client = okHttpClient;
this.keycloakConfiguration = keycloakConfiguration;
this.rbacRoleHelper = rbacRoleHelper;
this.tokenRoleResolver = tokenRoleResolver;
}

@Override
Expand All @@ -70,7 +68,6 @@ public Publisher<Authentication> validateToken(final String token, final HttpReq

private Authentication getAuthentication(final String token, final HttpRequest<?> request) {
final String payload = JwtTokenParser.getJwtPayloadToken(token);
final Collection<String> roles = new HashSet<>();

try {
final String jwtPayloadString = new String(Base64.getUrlDecoder().decode(payload), StandardCharsets.UTF_8);
Expand All @@ -81,11 +78,7 @@ private Authentication getAuthentication(final String token, final HttpRequest<?
log.debug("Performing authentication for auth user '{}'...", authUserId);

if (StringUtils.isNotBlank(authUserId)) {
log.debug("Successfully authenticated auth user '{}'.", authUserId);
roles.add(AuthRole.AUTHENTICATED_USER.toString());

log.debug("Fetching roles for auth user '{}'...", authUserId);
roles.addAll(rbacRoleHelper.getRbacRoles(authUserId, request));
final var roles = tokenRoleResolver.resolveRoles(authUserId, request);

log.debug("Authenticating user '{}' with roles {}...", authUserId, roles);
final var userAttributeMap = JwtTokenParser.convertJwtPayloadToUserAttributes(jwtPayload);
Expand All @@ -100,10 +93,19 @@ private Authentication getAuthentication(final String token, final HttpRequest<?
}

private Mono<Boolean> validateTokenWithKeycloak(final String token) {
final String realm;
try {
final Map<String, Object> jwtAttributes = JwtTokenParser.tokenToAttributes(token);
realm = (String) jwtAttributes.get(JwtTokenParser.JWT_SSO_REALM);
log.debug("Extracted realm {}", realm);
} catch (final Exception e) {
log.error("Failed to parse realm from JWT token: {}", token, e);
return Mono.just(false);
}
final okhttp3.Request request = new Request.Builder()
.addHeader(org.apache.http.HttpHeaders.CONTENT_TYPE, "application/json")
.addHeader(org.apache.http.HttpHeaders.AUTHORIZATION, "Bearer " + token)
.url(keycloakConfiguration.getKeycloakUserInfoEndpoint())
.url(keycloakConfiguration.getKeycloakUserInfoEndpointForRealm(realm))
.get()
.build();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.airbyte.server.config
package io.airbyte.commons.server.authorization

import io.micronaut.context.annotation.Factory
import jakarta.inject.Named
Expand All @@ -14,7 +14,5 @@ class HttpClientFactory {
*/
@Singleton
@Named("keycloakTokenValidatorHttpClient")
fun okHttpClient(): OkHttpClient {
return OkHttpClient()
}
fun okHttpClient(): OkHttpClient = OkHttpClient()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.airbyte.commons.server.authorization

import io.airbyte.commons.auth.AuthRole
import io.airbyte.commons.server.support.RbacRoleHelper
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.core.annotation.Nullable
import io.micronaut.http.HttpRequest
import jakarta.inject.Singleton

private val logger = KotlinLogging.logger {}

interface TokenRoleResolver {
fun resolveRoles(
@Nullable authUserId: String?,
httpRequest: HttpRequest<*>,
): Set<String>
}

@Singleton
class RbacTokenRoleResolver(
private val rbacRoleHelper: RbacRoleHelper,
) : TokenRoleResolver {
override fun resolveRoles(
@Nullable authUserId: String?,
httpRequest: HttpRequest<*>,
): Set<String> {
logger.debug { "Resolving roles for authUserId $authUserId" }

if (authUserId.isNullOrBlank()) {
logger.debug { "Provided authUserId is null or blank, returning empty role set" }
return setOf()
}

return mutableSetOf(AuthRole.AUTHENTICATED_USER.name).apply {
addAll(rbacRoleHelper.getRbacRoles(authUserId, httpRequest))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
* Copyright (c) 2020-2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.server;
package io.airbyte.commons.server.authorization;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import io.airbyte.commons.auth.config.AirbyteKeycloakConfiguration;
import io.airbyte.commons.server.support.RbacRoleHelper;
import io.airbyte.server.pro.KeycloakTokenValidator;
import io.micronaut.http.HttpHeaders;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.netty.NettyHttpHeaders;
Expand Down Expand Up @@ -60,7 +58,7 @@ class KeycloakTokenValidatorTest {
private KeycloakTokenValidator keycloakTokenValidator;
private OkHttpClient httpClient;
private AirbyteKeycloakConfiguration keycloakConfiguration;
private RbacRoleHelper rbacRoleHelper;
private TokenRoleResolver tokenRoleResolver;

@BeforeEach
void setUp() {
Expand All @@ -71,10 +69,10 @@ void setUp() {
httpClient = mock(OkHttpClient.class);

keycloakConfiguration = mock(AirbyteKeycloakConfiguration.class);
when(keycloakConfiguration.getKeycloakUserInfoEndpoint()).thenReturn(LOCALHOST + URI_PATH);
rbacRoleHelper = mock(RbacRoleHelper.class);
when(keycloakConfiguration.getKeycloakUserInfoEndpointForRealm(any())).thenReturn(LOCALHOST + URI_PATH);
tokenRoleResolver = mock(TokenRoleResolver.class);

keycloakTokenValidator = new KeycloakTokenValidator(httpClient, keycloakConfiguration, rbacRoleHelper);
keycloakTokenValidator = new KeycloakTokenValidator(httpClient, keycloakConfiguration, tokenRoleResolver);
}

@Test
Expand Down Expand Up @@ -107,7 +105,7 @@ void testValidateToken() throws Exception {
final Set<String> mockedRoles =
Set.of("ORGANIZATION_ADMIN", "ORGANIZATION_EDITOR", "ORGANIZATION_READER", "ORGANIZATION_MEMBER", "ADMIN", "EDITOR", "READER");

when(rbacRoleHelper.getRbacRoles(eq(expectedUserId), any(HttpRequest.class)))
when(tokenRoleResolver.resolveRoles(eq(expectedUserId), any(HttpRequest.class)))
.thenReturn(mockedRoles);

StepVerifier.create(responsePublisher)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package io.airbyte.commons.server.authorization

import io.airbyte.commons.auth.AuthRole
import io.airbyte.commons.server.support.RbacRoleHelper
import io.micronaut.http.HttpRequest
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class RbacTokenRoleResolverTest {
private lateinit var rbacRoleHelper: RbacRoleHelper
private lateinit var rbacTokenRoleResolver: RbacTokenRoleResolver

@BeforeEach
fun setup() {
rbacRoleHelper = mockk()
rbacTokenRoleResolver = RbacTokenRoleResolver(rbacRoleHelper)
}

@Test
fun `test resolveRoles with null authUserId`() {
val roles = rbacTokenRoleResolver.resolveRoles(null, HttpRequest.GET<Any>("/"))
assertEquals(setOf<String>(), roles)
}

@Test
fun `test resolveRoles with blank authUserId`() {
val roles = rbacTokenRoleResolver.resolveRoles("", HttpRequest.GET<Any>("/"))
assertEquals(setOf<String>(), roles)
}

@Test
fun `test resolveRoles with valid authUserId`() {
val authUserId = "test-user"
val expectedRoles = setOf("ORGANIZATION_ADMIN", "WORKSPACE_EDITOR")
every { rbacRoleHelper.getRbacRoles(authUserId, any(HttpRequest::class)) } returns expectedRoles

val roles = rbacTokenRoleResolver.resolveRoles(authUserId, HttpRequest.GET<Any>("/"))
assertEquals(setOf(AuthRole.AUTHENTICATED_USER.name).plus(expectedRoles), roles)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
@file:Suppress("ktlint:standard:package-name")

package io.airbyte.connector_builder.authorization

import io.airbyte.commons.auth.AuthRole
import io.airbyte.commons.server.authorization.TokenRoleResolver
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Primary
import jakarta.inject.Singleton

private val logger = KotlinLogging.logger {}

/**
* The Connector Builder Server's role resolver does not apply RBAC-specific roles, because they
* are not needed and currently inaccessible in the Connector Builder Server, which is isolated
* from other internal Airbyte applications (like the Config DB). If RBAC roles are needed in the
* future, the Connector Builder Server will need to be updated such that it is able to determine
* the RBAC roles of a user based on the Permissions stored in the Config DB.
*/
@Primary
@Singleton
class ConnectorBuilderTokenRoleResolver : TokenRoleResolver {
override fun resolveRoles(
authUserId: String?,
httpRequest: io.micronaut.http.HttpRequest<*>,
): Set<String> {
if (authUserId.isNullOrBlank()) {
logger.debug { "Provided authUserId is null or blank, returning empty role set" }
return setOf()
}

return setOf(AuthRole.AUTHENTICATED_USER.name)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
@file:Suppress("ktlint:standard:package-name")

package io.airbyte.connector_builder.authorization

import io.airbyte.commons.auth.AuthRole
import io.micronaut.http.HttpRequest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class ConnectorBuilderTokenRoleResolverTest {
private lateinit var resolver: ConnectorBuilderTokenRoleResolver

@BeforeEach
fun setup() {
resolver = ConnectorBuilderTokenRoleResolver()
}

@Test
fun `test resolveRoles with null authUserId`() {
val roles = resolver.resolveRoles(null, HttpRequest.GET<Any>("/"))
assertEquals(setOf<String>(), roles)
}

@Test
fun `test resolveRoles with blank authUserId`() {
val roles = resolver.resolveRoles("", HttpRequest.GET<Any>("/"))
assertEquals(setOf<String>(), roles)
}

@Test
fun `test resolveRoles with valid authUserId`() {
val authUserId = "test-user"
val roles = resolver.resolveRoles(authUserId, HttpRequest.GET<Any>("/"))
assertEquals(setOf(AuthRole.AUTHENTICATED_USER.name), roles)
}
}

0 comments on commit d3eb6f9

Please sign in to comment.