Skip to content

Commit

Permalink
Fix Management portal auth and resource enhancer
Browse files Browse the repository at this point in the history
  • Loading branch information
mpgxvii committed Sep 25, 2024
1 parent 8d00f90 commit 5923ad3
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ data class UserList(val users: List<User>)

data class User(val id: String, val projectId: String, val externalId: String? = null, val status: String)

fun MPSubject.toUser() = User(
fun MPSubject.toUser(projectId: String) = User(
id = checkNotNull(id) { "User must have a login" },
projectId = checkNotNull(projectId) { "User must have a project ID" },
projectId = projectId,
externalId = externalId,
status = status,
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ package org.radarbase.authorizer.api

import jakarta.ws.rs.core.Context
import org.radarbase.authorizer.doa.entity.RestSourceUser
import org.radarbase.authorizer.service.ProjectService
import org.radarbase.authorizer.service.RadarProjectService
import org.radarbase.authorizer.service.MPClient
import org.radarbase.kotlin.coroutines.forkJoin
import org.radarbase.jersey.auth.AuthService

class RestSourceUserMapper(
@Context private val authService: AuthService,
) {
private val projectService: ProjectService = ProjectService(MPClient(), authService)
private val projectService: RadarProjectService = RadarProjectService()

suspend fun fromEntity(user: RestSourceUser): RestSourceUserDTO {
val mpUser = user.projectId?.let { p ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ import org.radarbase.authorizer.service.RegistrationService
import org.radarbase.authorizer.service.RestSourceAuthorizationService
import org.radarbase.authorizer.service.RestSourceClientService
import org.radarbase.authorizer.service.RestSourceUserService
import org.radarbase.authorizer.service.RadarProjectService
import org.radarbase.jersey.enhancer.JerseyResourceEnhancer
import org.radarbase.jersey.filter.Filters
import org.radarbase.jersey.service.ProjectService

class AuthorizerResourceEnhancer(
private val config: AuthorizerConfig,
Expand Down Expand Up @@ -116,5 +118,9 @@ class AuthorizerResourceEnhancer(
.to(RestSourceAuthorizationService::class.java)
.named(OURA_AUTH)
.`in`(Singleton::class.java)

bind(RadarProjectService::class.java)
.to(ProjectService::class.java)
.`in`(Singleton::class.java)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class ManagementPortalEnhancerFactory(private val config: AuthorizerConfig) : En
val authConfig = AuthConfig(
managementPortal = MPConfig(
url = config.auth.managementPortalUrl.trimEnd('/'),
clientId = config.auth.clientId,
clientSecret = config.auth.clientSecret,
syncProjectsIntervalMin = config.service.syncProjectsIntervalMin,
syncParticipantsIntervalMin = config.service.syncParticipantsIntervalMin,
),
Expand All @@ -52,6 +50,7 @@ class ManagementPortalEnhancerFactory(private val config: AuthorizerConfig) : En
Enhancers.health,
HibernateResourceEnhancer(dbConfig),
Enhancers.managementPortal(authConfig),
Enhancers.ecdsa,
JedisResourceEnhancer(),
Enhancers.exception,
AuthorizerResourceEnhancer(config),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.radarbase.authorizer.api.UserList
import org.radarbase.authorizer.api.toProject
import org.radarbase.authorizer.api.toUser
import org.radarbase.authorizer.service.MPClient
import org.radarbase.authorizer.service.ProjectService
import org.radarbase.authorizer.service.RadarProjectService
import org.radarbase.jersey.auth.AuthService
import org.radarbase.jersey.auth.Authenticated
import org.radarbase.jersey.auth.NeedsPermission
Expand All @@ -46,9 +46,8 @@ import org.radarbase.jersey.service.AsyncCoroutineService
@Singleton
class ProjectResource(
@Context private val asyncService: AsyncCoroutineService,
@Context private val authService: AuthService,
) {
private val projectService: ProjectService = ProjectService(MPClient(), authService)
private val projectService: RadarProjectService = RadarProjectService()

@GET
@NeedsPermission(Permission.PROJECT_READ)
Expand All @@ -72,7 +71,7 @@ class ProjectResource(
UserList(
projectService
.projectSubjects(projectId)
.map { it.toUser() },
.map { it.toUser(projectId) },
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.radarbase.jersey.exception.HttpConflictException
import org.radarbase.jersey.service.AsyncCoroutineService
import java.net.URI
import org.radarbase.authorizer.service.MPClient
import org.radarbase.authorizer.service.ProjectService
import org.radarbase.authorizer.service.RadarProjectService

@Path("registrations")
@Produces(MediaType.APPLICATION_JSON)
Expand All @@ -48,10 +48,10 @@ class RegistrationResource(
@Context private val authorizationService: RestSourceAuthorizationService,
@Context private val userRepository: RestSourceUserRepository,
@Context private val registrationService: RegistrationService,
@Context private val authService: AuthService,
@Context private val asyncService: AsyncCoroutineService,
@Context private val authService: AuthService,
) {
private val projectService: ProjectService = ProjectService(MPClient(), authService)
private val projectService: RadarProjectService = RadarProjectService()

@POST
@Authenticated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import org.radarbase.jersey.exception.HttpBadRequestException
import org.radarbase.jersey.service.AsyncCoroutineService
import java.net.URI
import org.radarbase.authorizer.service.MPClient
import org.radarbase.authorizer.service.ProjectService
import org.radarbase.authorizer.service.RadarProjectService

@Path("users")
@Produces(MediaType.APPLICATION_JSON)
Expand All @@ -69,7 +69,7 @@ class RestSourceUserResource(
@Context private val asyncService: AsyncCoroutineService,
@Context private val authService: AuthService,
) {
private val projectService: ProjectService = ProjectService(MPClient(), authService)
private val projectService: RadarProjectService = RadarProjectService()

@GET
@NeedsPermission(Permission.SUBJECT_READ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import io.ktor.client.request.forms.*
import io.ktor.client.call.*
import org.radarbase.authorizer.config.AuthorizerConfig

@Singleton
class MPClient {
private val config: AuthorizerConfig = AuthorizerConfig()
private val logger = LoggerFactory.getLogger(MPClient::class.java)
Expand Down Expand Up @@ -53,6 +54,22 @@ class MPClient {
return tokenResponse.access_token
}

suspend fun requestOrganizations(page: Int = 0, size: Int = Int.MAX_VALUE): List<MPOrganization> {
val accessToken = getAccessToken()
val response: HttpResponse = httpClient.get("${config.auth.managementPortalUrl}/api/organizations") {
headers {
append(HttpHeaders.Authorization, "Bearer $accessToken")
}
}

if (!response.status.isSuccess()) {
logger.error("Failed to fetch projects: ${response.status}")
throw RuntimeException("Failed to fetch projects")
}

return response.body()
}

suspend fun requestProjects(page: Int = 0, size: Int = Int.MAX_VALUE): List<MPProject> {
val accessToken = getAccessToken()
val response: HttpResponse = httpClient.get("${config.auth.managementPortalUrl}/api/projects") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package org.radarbase.authorizer.service

import jakarta.inject.Singleton
import io.ktor.http.HttpStatusCode
import org.radarbase.jersey.auth.AuthService
import org.radarbase.kotlin.coroutines.CacheConfig
import org.radarbase.kotlin.coroutines.CachedMap
import org.radarbase.management.client.MPProject
import org.radarbase.management.client.MPSubject
import org.radarbase.management.client.MPOrganization
import org.radarbase.management.client.HttpStatusException
import org.slf4j.LoggerFactory
import java.time.Duration
import kotlin.time.Duration.Companion.minutes
Expand All @@ -17,12 +19,14 @@ import org.radarbase.auth.authorization.Permission
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentMap
import org.radarbase.jersey.exception.HttpNotFoundException
import org.radarbase.jersey.service.ProjectService

class RadarProjectService(
): ProjectService {
private val mpClient: MPClient = MPClient()

class ProjectService(
private val mpClient: MPClient,
private val authService: AuthService,
) {
private val projects: CachedMap<String, MPProject>
private val organizations: CachedMap<String, MPOrganization>
private val participants: ConcurrentMap<String, CachedMap<String, MPSubject>> = ConcurrentHashMap()

private val projectCacheConfig = CacheConfig(
Expand All @@ -47,34 +51,74 @@ class ProjectService(
throw RuntimeException("Unable to fetch projects", e)
}
}

organizations = CachedMap(cacheConfig) {
try {
mpClient.requestOrganizations()
.associateBy { it.id }
.also { logger.debug("Fetched organizations {}", it) }
} catch (ex: HttpStatusException) {
if (ex.code == HttpStatusCode.NotFound) {
logger.warn("Target ManagementPortal does not support organizations. Using default organization main.")
mapOf("main" to defaultOrganization)
} else {
throw ex
}
}
}
}

suspend fun getProjects(): List<MPProject> = projects.get().values.toList()

suspend fun userProjects(permission: Permission): List<MPProject> {
return projects.get()
.values
.filter {
authService.hasPermission(
permission,
EntityDetails(
organization = it.organization?.id,
project = it.id,
),
)
}
.values.toList()
// .filter {
// authService.hasPermission(
// permission,
// EntityDetails(
// organization = it.organization?.id,
// project = it.id,
// ),
// )
// }
}

suspend fun ensureProject(projectId: String) {
override suspend fun ensureSubject(projectId: String, userId: String) {
ensureProject(projectId)
if (!projectUserCache(projectId).contains(userId)) {
throw HttpNotFoundException("user_not_found", "User $userId not found in project $projectId of ManagementPortal.")
}
}

override suspend fun ensureOrganization(organizationId: String) {
if (!organizations.contains(organizationId)) {
throw HttpNotFoundException("organization_not_found", "Organization $organizationId not found in Management Portal.")
}
}

override suspend fun ensureProject(projectId: String) {
if (!projects.contains(projectId)) {
throw HttpNotFoundException("project_not_found", "Project $projectId not found in Management Portal.")
}
}

override suspend fun projectOrganization(projectId: String): String =
projects.get(projectId)?.organization?.id
?: throw HttpNotFoundException("project_not_found", "Project $projectId not found in Management Portal.")

suspend fun project(projectId: String): MPProject = projects.get(projectId)
?: throw HttpNotFoundException("project_not_found", "Project $projectId not found in Management Portal.")

suspend fun projectSubjects(projectId: String): List<MPSubject> = projectUserCache(projectId).get().values.toList()
override suspend fun listProjects(organizationId: String): List<String> = projects.get().asSequence()
.filter { it.value.organization?.id == organizationId }
.mapTo(ArrayList()) { it.key }

suspend fun projectSubjects(projectId: String): List<MPSubject> {
logger.info("Fetching subjects for project $projectId")
logger.info(projectUserCache(projectId).get().values.toList().toString())
return projectUserCache(projectId).get().values.toList()
}

private suspend fun projectUserCache(projectId: String) = participants.computeIfAbsent(projectId) {
CachedMap(projectCacheConfig) {
Expand Down

0 comments on commit 5923ad3

Please sign in to comment.