Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scala client: enhance mock Armada server for unit tests #4214

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/scala/scala-armada-client/.sbtopts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-J-Dsbt.io.implicit.relative.glob.conversion=allow
3 changes: 2 additions & 1 deletion client/scala/scala-armada-client/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ Compile / PB.protoSources ++= Seq(file("./proto"))

libraryDependencies ++= Seq(
"io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion,
"com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion
"com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion,
"com.github.jkugiya" %% "ulid-scala" % "1.0.5"
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.armadaproject.armada

import io.armadaproject.armada.ArmadaClient
import api.submit.{SubmitGrpc, CancellationResult, Queue, BatchQueueCreateResponse,
StreamingQueueMessage, JobReprioritizeResponse, JobSubmitResponse,
StreamingQueueMessage, Job, JobReprioritizeResponse, JobSubmitResponse,
BatchQueueUpdateResponse, JobSubmitResponseItem, JobSubmitRequestItem,
JobState, JobSetCancelRequest, JobCancelRequest, QueueDeleteRequest,
QueueGetRequest, StreamingQueueGetRequest, JobPreemptRequest,
Expand All @@ -15,10 +15,14 @@ import com.google.protobuf.empty.Empty
import api.health.HealthCheckResponse
import api.event.{EventGrpc, EventStreamMessage, JobSetRequest, WatchRequest}
import io.grpc.stub.StreamObserver
import io.grpc.{Server, ServerBuilder}
import io.grpc.{Server, ServerBuilder, Status, StatusRuntimeException}
import jkugiya.ulid.ULID

import java.util.concurrent.ConcurrentHashMap

import scala.concurrent.Future
import scala.util.Random
import javax.print.attribute.standard.JobPriority

private class EventMockServer extends EventGrpc.Event {
override def health(empty: Empty): scala.concurrent.Future[HealthCheckResponse] = {
Expand All @@ -35,7 +39,11 @@ private class EventMockServer extends EventGrpc.Event {
}
}

private class SubmitMockServer extends SubmitGrpc.Submit {
private class SubmitMockServer(jobMap: ConcurrentHashMap[String, Job], queueMap: ConcurrentHashMap[String, Queue])
extends SubmitGrpc.Submit {

val ulidGen = ULID.getGenerator()

def cancelJobSet(request: JobSetCancelRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = {
Future.successful(new Empty)
}
Expand All @@ -45,18 +53,27 @@ private class SubmitMockServer extends SubmitGrpc.Submit {
}

def createQueue(request: Queue): scala.concurrent.Future[com.google.protobuf.empty.Empty] = {
queueMap.put(request.name, request)
Future.successful(new Empty)
}

def createQueues(request: QueueList): scala.concurrent.Future[BatchQueueCreateResponse] = {
request.queues.foreach { q => queueMap.put(q.name, q) }
Future.successful(new BatchQueueCreateResponse)
}

def deleteQueue(request: QueueDeleteRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = {
queueMap.remove(request.name)
Future.successful(new Empty)
}

def getQueue(request: QueueGetRequest): scala.concurrent.Future[Queue] = {
Future.successful(new Queue)
val q = queueMap.get(request.name)
if (q == null) {
Future.failed(new StatusRuntimeException(Status.NOT_FOUND))
} else {
Future.successful(q)
}
}

def getQueues(request: StreamingQueueGetRequest, responseObserver: io.grpc.stub.StreamObserver[StreamingQueueMessage]): Unit = {
Expand All @@ -76,7 +93,17 @@ private class SubmitMockServer extends SubmitGrpc.Submit {
}

def submitJobs(request: JobSubmitRequest): scala.concurrent.Future[JobSubmitResponse] = {
Future.successful((new JobSubmitResponse(List(JobSubmitResponseItem("fakeJobId")))))
val q = queueMap.get(request.queue)
if (q == null) {
val msg = "could not find queue \"" + request.queue + "\""
return Future.failed(new StatusRuntimeException(Status.PERMISSION_DENIED.withDescription(msg)))
}

val jobId: String = ulidGen.base32().toLowerCase()
val newJob = new Job()
jobMap.put(jobId, newJob)

Future.successful(new JobSubmitResponse(List(JobSubmitResponseItem(jobId))))
}

def updateQueue(request: Queue): scala.concurrent.Future[com.google.protobuf.empty.Empty] = {
Expand All @@ -86,10 +113,9 @@ private class SubmitMockServer extends SubmitGrpc.Submit {
def updateQueues(request: QueueList): scala.concurrent.Future[BatchQueueUpdateResponse] = {
Future.successful(new BatchQueueUpdateResponse)
}

}

private class JobsMockServer extends JobsGrpc.Jobs {
private class JobsMockServer(jobMap: ConcurrentHashMap[String, Job]) extends JobsGrpc.Jobs {
def getJobDetails(request: JobDetailsRequest): scala.concurrent.Future[JobDetailsResponse] = {
Future.successful(new JobDetailsResponse)
}
Expand All @@ -103,8 +129,10 @@ private class JobsMockServer extends JobsGrpc.Jobs {
}

def getJobStatus(request: JobStatusRequest): scala.concurrent.Future[JobStatusResponse] = {
val response = new JobStatusResponse(Map("fakeJobId" -> JobState.RUNNING))
Future.successful(response)
val statusMap = collection.mutable.Map[String,JobState]() // jobID -> state
jobMap.keySet.forEach(k => statusMap.put(k, JobState.RUNNING))

Future.successful(new JobStatusResponse(statusMap.to(collection.immutable.Map)))
}

def getJobStatusUsingExternalJobUri(request: JobStatusUsingExternalJobUriRequest): scala.concurrent.Future[JobStatusResponse] = {
Expand All @@ -119,13 +147,17 @@ class ArmadaClientSuite extends munit.FunSuite {
val mockEventServer = new Fixture[Server]("Event GRPC Mock Server") {
private var server: Server = null
def apply() = server

private val jobMap: ConcurrentHashMap[String, Job] = new ConcurrentHashMap // key is job id
private val queueMap: ConcurrentHashMap[String, Queue] = new ConcurrentHashMap // key is queue name

override def beforeAll(): Unit = {
import scala.concurrent.ExecutionContext
server = ServerBuilder
.forPort(testPort)
.addService(EventGrpc.bindService(new EventMockServer, ExecutionContext.global))
.addService(SubmitGrpc.bindService(new SubmitMockServer, ExecutionContext.global))
.addService(JobsGrpc.bindService(new JobsMockServer, ExecutionContext.global))
.addService(SubmitGrpc.bindService(new SubmitMockServer(jobMap, queueMap), ExecutionContext.global))
.addService(JobsGrpc.bindService(new JobsMockServer(jobMap), ExecutionContext.global))
.build()
.start()
}
Expand All @@ -136,52 +168,71 @@ class ArmadaClientSuite extends munit.FunSuite {

override def munitFixtures = List(mockEventServer)

test("ArmadaClient.EventHealth()") {
test("ArmadaClient.eventHealth()") {
val ac = ArmadaClient("localhost", testPort)
val status = ac.EventHealth()
val status = ac.eventHealth()
assertEquals(status, HealthCheckResponse.ServingStatus.SERVING)
}

test("ArmadaClient.SubmitHealth()") {
test("ArmadaClient.submitHealth()") {
val ac = ArmadaClient("localhost", testPort)
val status = ac.SubmitHealth()
val status = ac.submitHealth()
assertEquals(status, HealthCheckResponse.ServingStatus.SERVING)
}

test("ArmadaClient.SubmitJobs()") {
test("ArmadaClient.submitJobs()") {
val ac = ArmadaClient("localhost", testPort)

// submission to non-existent queue
val qName = "nonexistent-queue-" + Random.alphanumeric.take(8).mkString
var response: JobSubmitResponse = new JobSubmitResponse()

intercept[StatusRuntimeException] {
response = ac.submitJobs(qName, "testJobSetId", List(new JobSubmitRequestItem()))
}
assertEquals(0, response.jobResponseItems.length)

// submission to existing queue
ac.createQueue(qName)
response = ac.submitJobs(qName, "testJobSetId", List(new JobSubmitRequestItem()))
assertEquals(1, response.jobResponseItems.length)
ac.deleteQueue(qName)
}

test("ArmadaClient.getJobStatus()") {
val ac = ArmadaClient("localhost", testPort)
val response = ac.SubmitJobs("testQueue", "testJobSetId", List(new JobSubmitRequestItem()))
assertEquals(response.jobResponseItems(0), JobSubmitResponseItem("fakeJobId"))
val qName = "getjobstatus-test-queue-" + Random.alphanumeric.take(8).mkString

ac.createQueue(qName)
val newJob = ac.submitJobs(qName, "testJobSetId", List(new JobSubmitRequestItem()))

val jobId = newJob.jobResponseItems(0).jobId
val jobStatus = ac.getJobStatus(jobId)
assert(jobStatus.jobStates(jobId).isRunning)

ac.deleteQueue(qName)
}

test("ArmadaClient.GetJobStatus()") {
test("ArmadaClient.{get,create,delete}Queue()") {
val ac = ArmadaClient("localhost", testPort)
val response = ac.GetJobStatus("fakeJobId")
assert(response.jobStates("fakeJobId").isRunning)
}

// Queue tests currently disabled - Armada mock server does not implement full queue
// state so these fail when running with mock; they pass with a real Armada instance
// test("test queue existence, creation, deletion") {
// val ac = new ArmadaClient(ArmadaClient.GetChannel("localhost", testPort))
// val qName = "test-queue-" + Random.alphanumeric.take(8).mkString
// var q: Queue = new Queue()

// // queue should not exist yet
// intercept[io.grpc.StatusRuntimeException] {
// q = ac.getQueue(qName)
// }
// assertNotEquals(q.name, qName)

// ac.createQueue(qName)
// q = ac.getQueue(qName)
// assertEquals(q.name, qName)

// ac.deleteQueue(qName)
// q = new Queue()
// intercept[io.grpc.StatusRuntimeException] {
// q = ac.getQueue(qName)
// }
// assertNotEquals(q.name, qName)
// }
val qName = "test-queue-" + Random.alphanumeric.take(8).mkString
var q: Queue = new Queue()

// queue should not exist yet
intercept[StatusRuntimeException] {
q = ac.getQueue(qName)
}
assertNotEquals(q.name, qName)

ac.createQueue(qName)
q = ac.getQueue(qName)
assertEquals(q.name, qName)

ac.deleteQueue(qName)
q = new Queue()
intercept[StatusRuntimeException] {
q = ac.getQueue(qName)
}
assertNotEquals(q.name, qName)
}
}
Loading