Skip to content

Commit

Permalink
DT-732 AWS X-Ray tracing via OpenTelemetry
Browse files Browse the repository at this point in the history
  • Loading branch information
BenRamchandani committed Jun 13, 2024
1 parent 9a5217e commit 9207645
Show file tree
Hide file tree
Showing 28 changed files with 446 additions and 71 deletions.
7 changes: 7 additions & 0 deletions auth-service/.env.template
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
LDAP_AUTH_SERVICE_USER_PASSWORD=Get from auth-service-ldap-user-password-test in Secrets Manager
# AZ_SSO_CLIENTS_JSON=[{"internalId":"dev","azTenantId":"example","azClientId":"example","azClientSecret":""}]

# Set to enable Telemetry, uploading traces to AWS X-Ray via OTel collector
# Run `docker compose --profile tracing up` after filling this in
# AUTH_TELEMETRY_PREFIX=local
# AWS_ACCESS_KEY_ID=User with AWSXRayDaemonWriteAccess
# AWS_SECRET_ACCESS_KEY=
# AWS_REGION=eu-west-1
12 changes: 12 additions & 0 deletions auth-service/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ dependencies {
testImplementation("org.jetbrains.kotlin:kotlin-test-junit:$kotlinVersion")
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
testImplementation("io.mockk:mockk:1.13.11")

// OpenTelemetry
api(platform("io.opentelemetry.instrumentation:opentelemetry-instrumentation-bom-alpha:2.4.0-alpha"))
implementation("io.opentelemetry:opentelemetry-api")
implementation("io.opentelemetry:opentelemetry-sdk")
implementation("io.opentelemetry:opentelemetry-exporter-otlp")
testImplementation("io.opentelemetry:opentelemetry-exporter-logging")
implementation("io.opentelemetry.instrumentation:opentelemetry-ktor-2.0")
implementation("io.opentelemetry.instrumentation:opentelemetry-jdbc")

implementation("io.opentelemetry.contrib:opentelemetry-aws-xray-propagator:1.36.0-alpha")
implementation("io.opentelemetry.contrib:opentelemetry-aws-xray:1.36.0")
}

// Migrations are run by the application on startup, or on first use of the database in Development mode.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fun Application.appModule() {
injection.forgotPasswordRateLimitCounter,
)
configureSecurity(injection)
configureMonitoring(injection.meterRegistry)
configureMonitoring(injection.meterRegistry, injection.openTelemetry)
configureSerialization()
configureTemplating(developmentMode)
configureRouting(injection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import io.micrometer.cloudwatch2.CloudWatchMeterRegistry
import io.micrometer.core.instrument.Clock
import io.micrometer.core.instrument.Counter
import io.micrometer.core.instrument.simple.SimpleMeterRegistry
import io.opentelemetry.context.Context
import org.slf4j.Logger
import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient
import uk.gov.communities.delta.auth.config.*
import uk.gov.communities.delta.auth.controllers.external.*
import uk.gov.communities.delta.auth.controllers.internal.*
import uk.gov.communities.delta.auth.plugins.SpanFactory
import uk.gov.communities.delta.auth.plugins.initOpenTelemetry
import uk.gov.communities.delta.auth.repositories.*
import uk.gov.communities.delta.auth.saml.SAMLTokenService
import uk.gov.communities.delta.auth.security.ADLdapLoginService
Expand All @@ -31,6 +34,7 @@ class Injection(
val emailConfig: EmailConfig,
val azureADSSOConfig: AzureADSSOConfig,
val authServiceConfig: AuthServiceConfig,
val tracingConfig: TracingConfig,
) {
companion object {
lateinit var instance: Injection
Expand All @@ -47,6 +51,7 @@ class Injection(
EmailConfig.fromEnv(),
AzureADSSOConfig.fromEnv(),
AuthServiceConfig.fromEnv(),
TracingConfig.fromEnv(),
)
return instance
}
Expand All @@ -60,6 +65,7 @@ class Injection(
clientConfig.log(logger.atInfo())
azureADSSOConfig.log(logger.atInfo())
authServiceConfig.log(logger.atInfo())
tracingConfig.log(logger.atInfo())
}

fun close() {
Expand All @@ -71,11 +77,21 @@ class Injection(
Runtime.getRuntime().addShutdownHook(Thread { close() })
}

val dbPool = DbPool(databaseConfig)
val openTelemetry = initOpenTelemetry(tracingConfig)

private val samlTracer = openTelemetry.getTracer("delta.auth.samlTokenGenerator")
private val ldapTracer = openTelemetry.getTracer("delta.auth.ldap")
private val ldapSpanFactory: SpanFactory = {
ldapTracer.spanBuilder(it).setParent(Context.current())
.setAttribute("peer.service", "ActiveDirectory")
.setAttribute("delta.request-to", "AD-ldap")
}

private val samlTokenService = SAMLTokenService(samlTracer)
val dbPool = DbPool(databaseConfig, openTelemetry)

private val samlTokenService = SAMLTokenService()
private val ldapRepository = LdapRepository(ldapConfig, LdapRepository.ObjectGUIDMode.NEW_JAVA_UUID_STRING)
private val ldapServiceUserBind = LdapServiceUserBind(ldapConfig, ldapRepository)
private val ldapServiceUserBind = LdapServiceUserBind(ldapConfig, ldapRepository, ldapSpanFactory)

val userAuditTrailRepo = UserAuditTrailRepo()
val userAuditService = UserAuditService(userAuditTrailRepo, dbPool)
Expand All @@ -86,7 +102,7 @@ class Injection(

val setPasswordTokenService = SetPasswordTokenService(dbPool, TimeSource.System)
val resetPasswordTokenService = ResetPasswordTokenService(dbPool, TimeSource.System)
val organisationService = OrganisationService(OrganisationService.makeHTTPClient(), deltaConfig)
val organisationService = OrganisationService(OrganisationService.makeHTTPClient(openTelemetry), deltaConfig)

private val userLookupService = UserLookupService(
ldapConfig.deltaUserDnFormat,
Expand Down Expand Up @@ -137,7 +153,7 @@ class Injection(
val ssoLoginStateService = SSOLoginSessionStateService()
val ssoOAuthClientProviderLookupService =
SSOOAuthClientProviderLookupService(azureADSSOConfig, ssoLoginStateService)
val microsoftGraphService = MicrosoftGraphService()
val microsoftGraphService = MicrosoftGraphService(openTelemetry)
val deltaUserDetailsRequestMapper = DeltaUserPermissionsRequestMapper(organisationService, accessGroupsService)
val meterRegistry =
if (authServiceConfig.metricsNamespace.isNullOrEmpty()) SimpleMeterRegistry() else CloudWatchMeterRegistry(
Expand Down Expand Up @@ -167,7 +183,8 @@ class Injection(
val deleteOldDeltaSessionsTask = DeleteOldDeltaSessions(dbPool)
val deleteOldApiTokensTask = DeleteOldApiTokens(dbPool)
val updateUserGuidMapTask = UpdateUserGUIDMap(ldapConfig, dbPool)
val tasks = listOf(deleteOldAuthCodesTask, deleteOldDeltaSessionsTask, deleteOldApiTokensTask, updateUserGuidMapTask)
val tasks =
listOf(deleteOldAuthCodesTask, deleteOldDeltaSessionsTask, deleteOldApiTokensTask, updateUserGuidMapTask)
return tasks.associateBy { it.name }
}

Expand All @@ -176,7 +193,7 @@ class Injection(
fun ldapServiceUserAuthenticationService(): LdapAuthenticationService {
val adLoginService = ADLdapLoginService(
ADLdapLoginService.Configuration(ldapConfig.serviceUserDnFormat),
ldapRepository
ldapRepository, ldapSpanFactory
)
return LdapAuthenticationService(adLoginService, ldapConfig.serviceUserRequiredGroupCn)
}
Expand All @@ -186,7 +203,7 @@ class Injection(
fun externalDeltaLoginController(): DeltaLoginController {
val adLoginService = ADLdapLoginService(
ADLdapLoginService.Configuration(ldapConfig.deltaUserDnFormat),
ldapRepository
ldapRepository, ldapSpanFactory
)
return DeltaLoginController(
clientConfig.oauthClients,
Expand Down Expand Up @@ -251,12 +268,13 @@ class Injection(
fun externalDeltaApiTokenController(): ExternalDeltaApiTokenController {
val adLoginService = ADLdapLoginService(
ADLdapLoginService.Configuration(ldapConfig.deltaUserDnFormat),
ldapRepository
ldapRepository, ldapSpanFactory
)
return ExternalDeltaApiTokenController(deltaApiTokenService, adLoginService)
}

fun internalDeltaApiTokenController() = InternalDeltaApiTokenController(deltaApiTokenService, samlTokenService, userLookupService)
fun internalDeltaApiTokenController() =
InternalDeltaApiTokenController(deltaApiTokenService, samlTokenService, userLookupService)

fun refreshUserInfoController() = RefreshUserInfoController(
userLookupService,
Expand Down Expand Up @@ -377,5 +395,6 @@ class Injection(

fun editRolesController() = EditRolesController(userLookupService, groupService)

fun editOrganisationsController() = EditOrganisationsController(userLookupService, groupService, organisationService)
fun editOrganisationsController() =
EditOrganisationsController(userLookupService, groupService, organisationService)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.http.content.*
import io.ktor.server.plugins.cors.routing.*
import io.ktor.server.plugins.ratelimit.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package uk.gov.communities.delta.auth.config

import org.slf4j.spi.LoggingEventBuilder

/**
 * Tracing/telemetry settings for the auth service.
 *
 * @property prefix value used to build the telemetry service name; `null` means tracing is disabled.
 */
class TracingConfig(val prefix: String?) {
    /** True when a prefix has been supplied. */
    val enabled = prefix != null

    /** `"<prefix>-auth-service"` when a prefix is set, otherwise `null`. */
    val serviceName = prefix?.let { "$it-auth-service" }

    /** Attaches the tracing settings as key/value pairs to [logger] and emits a single "Tracing config" event. */
    fun log(logger: LoggingEventBuilder) {
        val withEnabledFlag = logger.addKeyValue("TracingEnabled", enabled)
        val withServiceName = withEnabledFlag.addKeyValue("TelemetryServiceName", serviceName)
        withServiceName.log("Tracing config")
    }

    companion object {
        /** Reads `AUTH_TELEMETRY_PREFIX` from the environment; a missing or empty value yields a `null` prefix. */
        fun fromEnv(): TracingConfig {
            val configuredPrefix = Env.getEnv("AUTH_TELEMETRY_PREFIX").takeUnless { it.isNullOrEmpty() }
            return TracingConfig(configuredPrefix)
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,42 @@ import io.ktor.server.plugins.callloging.*
import io.ktor.server.request.*
import io.micrometer.core.instrument.MeterRegistry
import io.micrometer.core.instrument.config.MeterFilter
import io.opentelemetry.api.OpenTelemetry
import io.opentelemetry.api.common.AttributeKey
import io.opentelemetry.api.common.Attributes
import io.opentelemetry.api.trace.Span
import io.opentelemetry.api.trace.SpanBuilder
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator
import io.opentelemetry.context.Context
import io.opentelemetry.context.propagation.ContextPropagators
import io.opentelemetry.context.propagation.TextMapPropagator
import io.opentelemetry.contrib.awsxray.AwsXrayIdGenerator
import io.opentelemetry.contrib.awsxray.propagator.AwsXrayPropagator
import io.opentelemetry.exporter.otlp.trace.OtlpGrpcSpanExporter
import io.opentelemetry.instrumentation.ktor.v2_0.server.KtorServerTracing
import io.opentelemetry.sdk.OpenTelemetrySdk
import io.opentelemetry.sdk.resources.Resource
import io.opentelemetry.sdk.trace.SdkTracerProvider
import io.opentelemetry.sdk.trace.data.LinkData
import io.opentelemetry.sdk.trace.export.BatchSpanProcessor
import io.opentelemetry.sdk.trace.samplers.Sampler
import io.opentelemetry.sdk.trace.samplers.SamplingResult
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.MDC
import org.slf4j.event.Level
import uk.gov.communities.delta.auth.config.TracingConfig
import uk.gov.communities.delta.auth.security.CLIENT_HEADER_AUTH_NAME
import uk.gov.communities.delta.auth.security.ClientPrincipal
import uk.gov.communities.delta.auth.security.DELTA_AD_LDAP_SERVICE_USERS_AUTH_NAME
import uk.gov.communities.delta.auth.security.DeltaLdapPrincipal
import uk.gov.communities.delta.auth.services.OAuthSession
import kotlin.collections.set

fun Application.configureMonitoring(meterRegistry: MeterRegistry) {
fun Application.configureMonitoring(meterRegistry: MeterRegistry, openTelemetry: OpenTelemetry) {
// TODO: At some point we should replace this; we want structured logs
// and to clear the MDC before each request.
install(CallLogging) {
level = Level.INFO
callIdMdc("requestId")
Expand Down Expand Up @@ -48,6 +72,10 @@ fun Application.configureMonitoring(meterRegistry: MeterRegistry) {
.meterFilter(MeterFilter.acceptNameStartsWith("tasks."))
.meterFilter(MeterFilter.deny()) // Currently don't want any other metrics
}
install(KtorServerTracing) {
setOpenTelemetry(openTelemetry)
capturedRequestHeaders("X-Amz-Cf-Id")
}
}

internal object BeforeCall : Hook<suspend (ApplicationCall, suspend () -> Unit) -> Unit> {
Expand All @@ -67,6 +95,7 @@ val addServiceUserUsernameToMDC = createRouteScopedPlugin("AddUsernameToMdc") {
val principal = call.principal<DeltaLdapPrincipal>(DELTA_AD_LDAP_SERVICE_USERS_AUTH_NAME) ?: return@on proceed()
val mdcContextMap = MDC.getCopyOfContextMap() ?: mutableMapOf()
mdcContextMap["username"] = principal.username
Span.current().setAttribute("delta.username", principal.username)
withContext(MDCContext(mdcContextMap)) {
proceed()
}
Expand All @@ -78,6 +107,7 @@ val addClientIdToMDC = createRouteScopedPlugin("AddClientIdToMDC") {
val principal = call.principal<ClientPrincipal>(CLIENT_HEADER_AUTH_NAME) ?: return@on proceed()
val mdcContextMap = MDC.getCopyOfContextMap() ?: mutableMapOf()
mdcContextMap["clientId"] = principal.client.clientId
Span.current().setAttribute("delta.clientId", principal.client.clientId)
withContext(MDCContext(mdcContextMap)) {
proceed()
}
Expand All @@ -92,8 +122,66 @@ val addBearerSessionInfoToMDC = createRouteScopedPlugin("AddBearerSessionInfoToM
mdcContextMap["userGUID"] = session.userGUID.toString()
mdcContextMap["oauthSession"] = session.id.toString()
mdcContextMap["trace"] = session.traceId
val span = Span.current()
span.setAttribute("delta.username", session.userCn ?: "")
span.setAttribute("enduser.id", session.userGUID.toString())
span.setAttribute("delta.oauthSession", session.id.toString())
span.setAttribute("delta.trace", session.traceId)
withContext(MDCContext(mdcContextMap)) {
proceed()
}
}
}

/**
 * Builds the [OpenTelemetry] instance for the service and registers it as the global instance.
 *
 * Context propagation is always configured with both the W3C trace-context and AWS X-Ray propagators.
 * A tracer provider (OTLP gRPC exporter behind a batch processor, health-check-excluding sampler,
 * X-Ray ID generator) is only configured explicitly when [TracingConfig.enabled] is true.
 *
 * @param tracingConfig supplies the enabled flag and the `service.name` resource attribute
 *   ("DISABLED" is used when no service name is set).
 * @return the built and globally registered OpenTelemetry instance.
 */
fun initOpenTelemetry(tracingConfig: TracingConfig): OpenTelemetry {
    val serviceResource = Resource.getDefault().toBuilder()
        .put(AttributeKey.stringKey("service.name"), tracingConfig.serviceName ?: "DISABLED")
        .build()

    val sdkBuilder = OpenTelemetrySdk.builder()

    // Propagate the X-Ray trace header (alongside the W3C trace context)
    val traceHeaderPropagator = TextMapPropagator.composite(
        W3CTraceContextPropagator.getInstance(), AwsXrayPropagator.getInstance()
    )
    sdkBuilder.setPropagators(ContextPropagators.create(traceHeaderPropagator))

    if (tracingConfig.enabled) {
        val batchingProcessor = BatchSpanProcessor.builder(OtlpGrpcSpanExporter.getDefault()).build()
        val tracerProvider = SdkTracerProvider.builder()
            .setResource(serviceResource)
            .addSpanProcessor(batchingProcessor)
            .setSampler(Sampler.parentBased(NoHealthChecksSampler()))
            // Generate X-Ray compliant span IDs
            .setIdGenerator(AwsXrayIdGenerator.getInstance())
            .build()
        sdkBuilder.setTracerProvider(tracerProvider)
    }

    return sdkBuilder.buildAndRegisterGlobal()
}

/**
 * Sampler that drops any span whose `url.path` attribute is exactly `/health`
 * and records and samples every other span.
 */
class NoHealthChecksSampler : Sampler {
    private val pathKey = AttributeKey.stringKey("url.path")

    /** Only [attributes] is consulted; the remaining parameters are not used in the decision. */
    override fun shouldSample(
        parentContext: Context,
        traceId: String,
        name: String,
        spanKind: SpanKind,
        attributes: Attributes,
        parentLinks: MutableList<LinkData>
    ): SamplingResult {
        val requestPath = attributes.get(pathKey)
        if (requestPath == "/health") {
            return SamplingResult.drop()
        }
        return SamplingResult.recordAndSample()
    }

    override fun getDescription(): String = "Custom sampler that excludes requests to /health"
}

/**
 * Function type that takes a [String] and returns a [SpanBuilder].
 * NOTE(review): the factory built in `Injection` passes the string to `Tracer.spanBuilder`,
 * so it appears to be the span name; confirm this holds for other implementations.
 */
typealias SpanFactory = (String) -> SpanBuilder
Loading

0 comments on commit 9207645

Please sign in to comment.