diff --git a/client/scala/scala-armada-client/.sbtopts b/client/scala/scala-armada-client/.sbtopts new file mode 100644 index 00000000000..ead90e50de5 --- /dev/null +++ b/client/scala/scala-armada-client/.sbtopts @@ -0,0 +1 @@ +-J-Dsbt.io.implicit.relative.glob.conversion=allow diff --git a/client/scala/scala-armada-client/build.sbt b/client/scala/scala-armada-client/build.sbt index 93e27000aa5..4b9b08f46bd 100644 --- a/client/scala/scala-armada-client/build.sbt +++ b/client/scala/scala-armada-client/build.sbt @@ -21,5 +21,6 @@ Compile / PB.protoSources ++= Seq(file("./proto")) libraryDependencies ++= Seq( "io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion, - "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion + "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion, + "com.github.jkugiya" %% "ulid-scala" % "1.0.5" ) diff --git a/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala b/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala index c533a8d2900..56fd56fe60c 100644 --- a/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala +++ b/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala @@ -2,7 +2,7 @@ package io.armadaproject.armada import io.armadaproject.armada.ArmadaClient import api.submit.{SubmitGrpc, CancellationResult, Queue, BatchQueueCreateResponse, - StreamingQueueMessage, JobReprioritizeResponse, JobSubmitResponse, + StreamingQueueMessage, Job, JobReprioritizeResponse, JobSubmitResponse, BatchQueueUpdateResponse, JobSubmitResponseItem, JobSubmitRequestItem, JobState, JobSetCancelRequest, JobCancelRequest, QueueDeleteRequest, QueueGetRequest, StreamingQueueGetRequest, JobPreemptRequest, @@ -15,10 +15,14 @@ import com.google.protobuf.empty.Empty import api.health.HealthCheckResponse import api.event.{EventGrpc, EventStreamMessage, JobSetRequest, WatchRequest} import io.grpc.stub.StreamObserver -import io.grpc.{Server, ServerBuilder} +import io.grpc.{Server, ServerBuilder, Status, StatusRuntimeException} +import jkugiya.ulid.ULID + +import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future import scala.util.Random +import javax.print.attribute.standard.JobPriority private class EventMockServer extends EventGrpc.Event { override def health(empty: Empty): scala.concurrent.Future[HealthCheckResponse] = { @@ -35,7 +39,11 @@ private class EventMockServer extends EventGrpc.Event { } } -private class SubmitMockServer extends SubmitGrpc.Submit { +private class SubmitMockServer(jobMap: ConcurrentHashMap[String, Job], queueMap: ConcurrentHashMap[String, Queue]) + extends SubmitGrpc.Submit { + + val ulidGen = ULID.getGenerator() + def cancelJobSet(request: JobSetCancelRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { Future.successful(new Empty) } @@ -45,18 +53,27 @@ private class SubmitMockServer extends SubmitGrpc.Submit { } def createQueue(request: Queue): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + queueMap.put(request.name, request) Future.successful(new Empty) } def createQueues(request: QueueList): scala.concurrent.Future[BatchQueueCreateResponse] = { + request.queues.foreach { q => queueMap.put(q.name, q) } Future.successful(new BatchQueueCreateResponse) } def deleteQueue(request: QueueDeleteRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + queueMap.remove(request.name) Future.successful(new Empty) } + def getQueue(request: QueueGetRequest): scala.concurrent.Future[Queue] = { - Future.successful(new Queue) + val q = queueMap.get(request.name) + if (q == null) { + Future.failed(new StatusRuntimeException(Status.NOT_FOUND)) + } else { + Future.successful(q) + } } def getQueues(request: StreamingQueueGetRequest, responseObserver: io.grpc.stub.StreamObserver[StreamingQueueMessage]): Unit = { @@ -76,7 +93,17 @@ private class SubmitMockServer extends SubmitGrpc.Submit { } def submitJobs(request: JobSubmitRequest): scala.concurrent.Future[JobSubmitResponse] = { - Future.successful((new JobSubmitResponse(List(JobSubmitResponseItem("fakeJobId"))))) + val q = queueMap.get(request.queue) + if (q == null) { + val msg = "could not find queue \"" + request.queue + "\"" + return Future.failed(new StatusRuntimeException(Status.PERMISSION_DENIED.withDescription(msg))) + } + + val jobId: String = ulidGen.base32().toLowerCase() + val newJob = new Job() + jobMap.put(jobId, newJob) + + Future.successful(new JobSubmitResponse(List(JobSubmitResponseItem(jobId)))) } def updateQueue(request: Queue): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { @@ -86,10 +113,9 @@ private class SubmitMockServer extends SubmitGrpc.Submit { def updateQueues(request: QueueList): scala.concurrent.Future[BatchQueueUpdateResponse] = { Future.successful(new BatchQueueUpdateResponse) } - } -private class JobsMockServer extends JobsGrpc.Jobs { +private class JobsMockServer(jobMap: ConcurrentHashMap[String, Job]) extends JobsGrpc.Jobs { def getJobDetails(request: JobDetailsRequest): scala.concurrent.Future[JobDetailsResponse] = { Future.successful(new JobDetailsResponse) } @@ -103,8 +129,10 @@ private class JobsMockServer extends JobsGrpc.Jobs { } def getJobStatus(request: JobStatusRequest): scala.concurrent.Future[JobStatusResponse] = { - val response = new JobStatusResponse(Map("fakeJobId" -> JobState.RUNNING)) - Future.successful(response) + val statusMap = collection.mutable.Map[String,JobState]() // jobID -> state + jobMap.keySet.forEach(k => statusMap.put(k, JobState.RUNNING)) + + Future.successful(new JobStatusResponse(statusMap.to(collection.immutable.Map))) } def getJobStatusUsingExternalJobUri(request: JobStatusUsingExternalJobUriRequest): scala.concurrent.Future[JobStatusResponse] = { @@ -119,13 +147,17 @@ class ArmadaClientSuite extends munit.FunSuite { val mockEventServer = new Fixture[Server]("Event GRPC Mock Server") { private var server: Server = null def apply() = server + + private val jobMap: ConcurrentHashMap[String, Job] = new ConcurrentHashMap // key is job id + private val queueMap: ConcurrentHashMap[String, Queue] = new ConcurrentHashMap // key is queue name + override def beforeAll(): Unit = { import scala.concurrent.ExecutionContext server = ServerBuilder .forPort(testPort) .addService(EventGrpc.bindService(new EventMockServer, ExecutionContext.global)) - .addService(SubmitGrpc.bindService(new SubmitMockServer, ExecutionContext.global)) - .addService(JobsGrpc.bindService(new JobsMockServer, ExecutionContext.global)) + .addService(SubmitGrpc.bindService(new SubmitMockServer(jobMap, queueMap), ExecutionContext.global)) + .addService(JobsGrpc.bindService(new JobsMockServer(jobMap), ExecutionContext.global)) .build() .start() } @@ -136,52 +168,71 @@ class ArmadaClientSuite extends munit.FunSuite { override def munitFixtures = List(mockEventServer) - test("ArmadaClient.EventHealth()") { + test("ArmadaClient.eventHealth()") { val ac = ArmadaClient("localhost", testPort) - val status = ac.EventHealth() + val status = ac.eventHealth() assertEquals(status, HealthCheckResponse.ServingStatus.SERVING) } - test("ArmadaClient.SubmitHealth()") { + test("ArmadaClient.submitHealth()") { val ac = ArmadaClient("localhost", testPort) - val status = ac.SubmitHealth() + val status = ac.submitHealth() assertEquals(status, HealthCheckResponse.ServingStatus.SERVING) } - test("ArmadaClient.SubmitJobs()") { + test("ArmadaClient.submitJobs()") { + val ac = ArmadaClient("localhost", testPort) + + // submission to non-existent queue + val qName = "nonexistent-queue-" + Random.alphanumeric.take(8).mkString + var response: JobSubmitResponse = new JobSubmitResponse() + + intercept[StatusRuntimeException] { + response = ac.submitJobs(qName, "testJobSetId", List(new JobSubmitRequestItem())) + } + assertEquals(0, response.jobResponseItems.length) + + // submission to existing queue + ac.createQueue(qName) + response = ac.submitJobs(qName, "testJobSetId", List(new JobSubmitRequestItem())) + assertEquals(1, response.jobResponseItems.length) + ac.deleteQueue(qName) + } + + test("ArmadaClient.getJobStatus()") { val ac = ArmadaClient("localhost", testPort) - val response = ac.SubmitJobs("testQueue", "testJobSetId", List(new JobSubmitRequestItem())) - assertEquals(response.jobResponseItems(0), JobSubmitResponseItem("fakeJobId")) + val qName = "getjobstatus-test-queue-" + Random.alphanumeric.take(8).mkString + + ac.createQueue(qName) + val newJob = ac.submitJobs(qName, "testJobSetId", List(new JobSubmitRequestItem())) + + val jobId = newJob.jobResponseItems(0).jobId + val jobStatus = ac.getJobStatus(jobId) + assert(jobStatus.jobStates(jobId).isRunning) + + ac.deleteQueue(qName) } - test("ArmadaClient.GetJobStatus()") { + test("ArmadaClient.{get,create,delete}Queue()") { val ac = ArmadaClient("localhost", testPort) - val response = ac.GetJobStatus("fakeJobId") - assert(response.jobStates("fakeJobId").isRunning) - } - - // Queue tests currently disabled - Armada mock server does not implement full queue - // state so these fail when running with mock; they pass with a real Armada instance - // test("test queue existence, creation, deletion") { - // val ac = new ArmadaClient(ArmadaClient.GetChannel("localhost", testPort)) - // val qName = "test-queue-" + Random.alphanumeric.take(8).mkString - // var q: Queue = new Queue() - - // // queue should not exist yet - // intercept[io.grpc.StatusRuntimeException] { - // q = ac.getQueue(qName) - // } - // assertNotEquals(q.name, qName) - - // ac.createQueue(qName) - // q = ac.getQueue(qName) - // assertEquals(q.name, qName) - - // ac.deleteQueue(qName) - // q = new Queue() - // intercept[io.grpc.StatusRuntimeException] { - // q = ac.getQueue(qName) - // } - // assertNotEquals(q.name, qName) - // } + val qName = "test-queue-" + Random.alphanumeric.take(8).mkString + var q: Queue = new Queue() + + // queue should not exist yet + intercept[StatusRuntimeException] { + q = ac.getQueue(qName) + } + assertNotEquals(q.name, qName) + + ac.createQueue(qName) + q = ac.getQueue(qName) + assertEquals(q.name, qName) + + ac.deleteQueue(qName) + q = new Queue() + intercept[StatusRuntimeException] { + q = ac.getQueue(qName) + } + assertNotEquals(q.name, qName) + } }