Skip to content

Commit

Permalink
[KYUUBI #4323] Improve trino session context
Browse files Browse the repository at this point in the history
### _Why are the changes needed?_

This pr improves the trino session context:
1. always reuse the kyuubi session if session id exists, so we can restore the session context for next query
2. transform trino client information to kyuubi session, e.g. trino request source (trino-cli)

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4323 from ulysses-you/trino-session.

Closes #4323

59804ba [ulysses-you] style
fcf540a [ulysses-you] Improve trino session context

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
  • Loading branch information
ulysses-you committed Feb 14, 2023
1 parent 9ce5ef6 commit 763c088
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ import java.util.concurrent.atomic.AtomicLong
import javax.ws.rs.WebApplicationException
import javax.ws.rs.core.{Response, UriInfo}

import scala.collection.mutable

import Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
import com.google.common.hash.Hashing
import io.trino.client.QueryResults
import org.apache.hive.service.rpc.thrift.TProtocolVersion

import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
import org.apache.kyuubi.operation.OperationState.{FINISHED, INITIALIZED, OperationState, PENDING}
import org.apache.kyuubi.server.trino.api.Query.KYUUBI_SESSION_ID
import org.apache.kyuubi.service.BackendService
import org.apache.kyuubi.session.SessionHandle

Expand Down Expand Up @@ -90,7 +93,7 @@ case class Query(

private def clear = {
be.closeOperation(queryId.operationHandle)
context.session.get("sessionId").foreach { id =>
context.session.get(KYUUBI_SESSION_ID).foreach { id =>
be.closeSession(SessionHandle.fromUUID(id))
}
}
Expand Down Expand Up @@ -128,39 +131,64 @@ case class Query(

object Query {

val KYUUBI_SESSION_ID = "kyuubi.session.id"

def apply(
statement: String,
context: TrinoContext,
translator: KyuubiTrinoOperationTranslator,
backendService: BackendService,
queryTimeout: Long = 0): Query = {

val sessionHandle = createSession(context, backendService)
val sessionHandle = getOrCreateSession(context, backendService)
val operationHandle = translator.transform(
statement,
sessionHandle,
context.session,
true,
queryTimeout)
val newSessionProperties =
context.session + ("sessionId" -> sessionHandle.identifier.toString)
val updatedContext = context.copy(session = newSessionProperties)
val sessionWithId =
context.session + (KYUUBI_SESSION_ID -> sessionHandle.identifier.toString)
val updatedContext = context.copy(session = sessionWithId)
Query(QueryId(operationHandle), updatedContext, backendService)
}

def apply(id: String, context: TrinoContext, backendService: BackendService): Query = {
Query(QueryId(id), context, backendService)
}

private def createSession(
private def getOrCreateSession(
context: TrinoContext,
backendService: BackendService): SessionHandle = {
backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
context.user,
"",
context.remoteUserAddress.getOrElse(""),
context.session)
context.session.get(KYUUBI_SESSION_ID).map(SessionHandle.fromUUID).getOrElse {
// transform Trino information to session and engine as far as possible.
val trinoInfo = new mutable.HashMap[String, String]()
context.clientInfo.foreach { info =>
trinoInfo.put("trino.client.info", info)
}
context.source.foreach { source =>
trinoInfo.put("trino.request.source", source)
}
context.traceToken.foreach { traceToken =>
trinoInfo.put("trino.trace.token", traceToken)
}
context.timeZone.foreach { timeZone =>
trinoInfo.put("trino.time.zone", timeZone)
}
context.language.foreach { language =>
trinoInfo.put("trino.language", language)
}
if (context.clientTags.nonEmpty) {
trinoInfo.put("trino.client.info", context.clientTags.mkString(","))
}

val newSessionConfigs = context.session ++ trinoInfo
backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
context.user,
"",
context.remoteUserAddress.getOrElse(""),
newSessionConfigs)
}
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, T

import org.apache.kyuubi.operation.OperationState.FINISHED
import org.apache.kyuubi.operation.OperationStatus
import org.apache.kyuubi.server.trino.api.Query.KYUUBI_SESSION_ID

// TODO: Support replace `preparedStatement` for Trino-jdbc
/**
* The description and functionality of trino request
* and response's context
Expand Down Expand Up @@ -140,15 +142,17 @@ object TrinoContext {
def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = {
val responseBuilder = Response.ok(qr)

trinoContext.catalog.foreach(
responseBuilder.header(TRINO_HEADERS.responseSetCatalog, _))
trinoContext.schema.foreach(
responseBuilder.header(TRINO_HEADERS.responseSetSchema, _))
// Note, We have injected kyuubi session id to session context so that the next query can find
// the previous session to restore the query context.
// It's hard to follow the Trino style that set all context to http headers.
// Because we do not know the context at server side. e.g. `set k=v`, `use database`.
// We also can not inject other session context into header before we supporting to map
// query result to session context.
require(trinoContext.session.contains(KYUUBI_SESSION_ID), s"$KYUUBI_SESSION_ID must be set.")
responseBuilder.header(
TRINO_HEADERS.responseSetSession,
s"$KYUUBI_SESSION_ID=${urlEncode(trinoContext.session(KYUUBI_SESSION_ID))}")

trinoContext.session.foreach {
case (k, v) =>
responseBuilder.header(TRINO_HEADERS.responseSetSession, s"${k}=${urlEncode(v)}")
}
trinoContext.preparedStatement.foreach {
case (k, v) =>
responseBuilder.header(TRINO_HEADERS.responseAddedPrepare, s"${k}=${urlEncode(v)}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,10 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
slug: String,
token: Long,
slugContext: Slug.Context.Context): Try[Query] = {

Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map { _ =>
Query(queryId, context, be)
Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map { op =>
val sessionWithId = context.session ++
Map(Query.KYUUBI_SESSION_ID -> op.getSession.handle.identifier.toString)
Query(queryId, context.copy(session = sessionWithId), be)
}.filter(_.getSlug.isValid(slugContext, slug, token))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,24 @@ class TrinoClientApiSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelpe
test("submit query with trino client api") {
val trino = getTrinoStatementClient("select 1")
val result = execute(trino)
val sessionId = trino.getSetSessionProperties.asScala.get("sessionId")
val sessionId = trino.getSetSessionProperties.asScala.get(Query.KYUUBI_SESSION_ID)
assert(result == List(List(1)))

updateClientSession(trino)

val trino1 = getTrinoStatementClient("select 2")
val trino1 = getTrinoStatementClient("set k=v")
val result1 = execute(trino1)
val sessionId1 = trino1.getSetSessionProperties.asScala.get("sessionId")
assert(result1 == List(List(2)))
assert(sessionId != sessionId1)
val sessionId1 = trino1.getSetSessionProperties.asScala.get(Query.KYUUBI_SESSION_ID)
assert(result1 == List(List("k", "v")))
assert(sessionId == sessionId1)

updateClientSession(trino)

val trino2 = getTrinoStatementClient("set k")
val result2 = execute(trino2)
val sessionId2 = trino2.getSetSessionProperties.asScala.get(Query.KYUUBI_SESSION_ID)
assert(result2 == List(List("k", "v")))
assert(sessionId == sessionId2)

trino.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import io.trino.client.ProtocolHeaders.TRINO_HEADERS

import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper}
import org.apache.kyuubi.operation.{OperationHandle, OperationState}
import org.apache.kyuubi.server.trino.api.TrinoContext
import org.apache.kyuubi.server.trino.api.{Query, TrinoContext}
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
import org.apache.kyuubi.session.SessionHandle

Expand Down Expand Up @@ -78,7 +78,7 @@ class StatementResourceSuite extends KyuubiFunSuite with TrinoRestFrontendTestHe
response.getStringHeaders.get(TRINO_HEADERS.responseSetSession).asScala
.map(_.split("="))
.find {
case Array("sessionId", _) => true
case Array(Query.KYUUBI_SESSION_ID, _) => true
}
.map {
case Array(_, value) => SessionHandle.fromUUID(TrinoContext.urlDecode(value))
Expand All @@ -90,12 +90,12 @@ class StatementResourceSuite extends KyuubiFunSuite with TrinoRestFrontendTestHe
val path = qr.getNextUri.getPath
val nextResponse = webTarget.path(path).request().header(
TRINO_HEADERS.requestSession(),
s"sessionId=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}").delete()
s"${Query.KYUUBI_SESSION_ID}=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}")
.delete()
assert(nextResponse.getStatus == 204)
assert(operation.getStatus.state == OperationState.CLOSED)
val exception = intercept[KyuubiSQLException](sessionManager.getSession(sessionHandle))
assert(exception.getMessage === s"Invalid $sessionHandle")

}

private def getData(current: TrinoResponse): TrinoResponse = {
Expand Down

0 comments on commit 763c088

Please sign in to comment.