From 4db5dc83cf83541c3bf4fe7a18bc24eb7521c2b7 Mon Sep 17 00:00:00 2001 From: Tanzim Hossain Date: Mon, 15 Jul 2024 17:05:34 -0400 Subject: [PATCH] Add Vertex AI unit tests (#6090) --- .github/workflows/ci_tests.yml | 5 + .gitignore | 1 + README.md | 5 + firebase-vertexai/README.md | 4 + .../firebase-vertexai.gradle.kts | 11 +- .../vertexai/StreamingSnapshotTests.kt | 187 +++++++++++ .../firebase/vertexai/UnarySnapshotTests.kt | 295 ++++++++++++++++++ .../google/firebase/vertexai/util/kotlin.kt | 35 +++ .../google/firebase/vertexai/util/tests.kt | 182 +++++++++++ .../src/test/resources/README.md | 2 + firebase-vertexai/update_responses.sh | 22 ++ 11 files changed, 747 insertions(+), 2 deletions(-) create mode 100644 firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt create mode 100644 firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt create mode 100644 firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/kotlin.kt create mode 100644 firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt create mode 100644 firebase-vertexai/src/test/resources/README.md create mode 100755 firebase-vertexai/update_responses.sh diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index 42d45537628..508d5d6f8e5 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -64,6 +64,11 @@ jobs: ./gradlew :common:updateVersion common:publishToMavenLocal cd .. + - name: Clone mock responses + if: matrix.module == ':firebase-vertexai' + run: | + firebase-vertexai/update_responses.sh + - name: Add google-services.json env: INTEG_TESTS_GOOGLE_SERVICES: ${{ secrets.INTEG_TESTS_GOOGLE_SERVICES }} diff --git a/.gitignore b/.gitignore index d76e10d3c80..a978cc8a8a1 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ smoke-test-logs/ smoke-tests/build-debug-headGit-smoke-test smoke-tests/firehorn.log macrobenchmark-output.json +vertexai-sdk-test-data/ # generated Terraform docs .terraform/* diff --git a/README.md b/README.md index 4e6500c4820..d9154b826b0 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,11 @@ Unit tests can be executed on the command line by running ./gradlew ::check ``` +#### Vertex AI for Firebase + +See the Vertex AI for Firebase [README](firebase-vertexai#running-tests) for setup +instructions specific to that project. + ### Integration Testing These are tests that run on a hardware device or emulator. These tests have diff --git a/firebase-vertexai/README.md b/firebase-vertexai/README.md index dcc31b8eef6..df5f5f15bf0 100644 --- a/firebase-vertexai/README.md +++ b/firebase-vertexai/README.md @@ -15,6 +15,10 @@ All Gradle commands should be run from the root of this repository. ## Running Tests +> [!IMPORTANT] +> These unit tests require mock response files, which can be downloaded by running +`./firebase-vertexai/update_responses.sh` from the root of this repository. + Unit tests: `./gradlew :firebase-vertexai:check` diff --git a/firebase-vertexai/firebase-vertexai.gradle.kts b/firebase-vertexai/firebase-vertexai.gradle.kts index 929f1ba1549..01f46b60a5b 100644 --- a/firebase-vertexai/firebase-vertexai.gradle.kts +++ b/firebase-vertexai/firebase-vertexai.gradle.kts @@ -49,7 +49,10 @@ android { targetCompatibility = JavaVersion.VERSION_1_8 } kotlinOptions { jvmTarget = "1.8" } - testOptions.unitTests.isIncludeAndroidResources = true + testOptions { + unitTests.isIncludeAndroidResources = true + unitTests.isReturnDefaultValues = true + } } dependencies { @@ -58,7 +61,7 @@ dependencies { implementation("com.google.firebase:firebase-components:18.0.0") implementation("com.google.firebase:firebase-annotations:16.2.0") implementation("com.google.firebase:firebase-appcheck-interop:17.1.0") - implementation("com.google.ai.client.generativeai:common:0.7.1") + implementation("com.google.ai.client.generativeai:common:0.9.0") implementation(libs.androidx.annotation) implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") implementation("androidx.core:core-ktx:1.12.0") @@ -71,8 +74,12 @@ dependencies { implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03") implementation("com.google.firebase:firebase-auth-interop:18.0.0") + val ktorVersion = "2.3.2" testImplementation("io.kotest:kotest-assertions-core:5.5.5") testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5") + testImplementation("io.ktor:ktor-client-okhttp:$ktorVersion") + testImplementation("io.ktor:ktor-client-mock:$ktorVersion") + testImplementation("org.json:json:20240303") testImplementation(libs.androidx.test.junit) testImplementation(libs.androidx.test.runner) testImplementation(libs.junit) diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt new file mode 100644 index 00000000000..499bd12c0b9 --- /dev/null +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/StreamingSnapshotTests.kt @@ -0,0 +1,187 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai + +import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.HarmCategory +import com.google.firebase.vertexai.type.InvalidAPIKeyException +import com.google.firebase.vertexai.type.PromptBlockedException +import com.google.firebase.vertexai.type.ResponseStoppedException +import com.google.firebase.vertexai.type.SerializationException +import com.google.firebase.vertexai.type.ServerException +import com.google.firebase.vertexai.type.TextPart +import com.google.firebase.vertexai.util.goldenStreamingFile +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.http.HttpStatusCode +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.withTimeout +import org.junit.Test + +internal class StreamingSnapshotTests { + private val testTimeout = 5.seconds + + @Test + fun `short reply`() = + goldenStreamingFile("success-basic-reply-short.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + responseList.isEmpty() shouldBe false + responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP + responseList.first().candidates.first().content.parts.isEmpty() shouldBe false + responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false + } + } + + @Test + fun `long reply`() = + goldenStreamingFile("success-basic-reply-long.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + responseList.isEmpty() shouldBe false + responseList.forEach { + it.candidates.first().finishReason shouldBe FinishReason.STOP + it.candidates.first().content.parts.isEmpty() shouldBe false + it.candidates.first().safetyRatings.isEmpty() shouldBe false + } + } + } + + @Test + fun `unknown enum`() = + goldenStreamingFile("success-unknown-enum.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + + responseList.isEmpty() shouldBe false + responseList.any { + it.candidates.any { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } + } shouldBe true + } + } + + @Test + fun `quotes escaped`() = + goldenStreamingFile("success-quotes-escaped.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + + responseList.isEmpty() shouldBe false + val part = responseList.first().candidates.first().content.parts.first() as? TextPart + part.shouldNotBeNull() + part.text shouldContain "\"" + } + } + + @Test + fun `prompt blocked for safety`() = + goldenStreamingFile("failure-prompt-blocked-safety.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val exception = shouldThrow { responses.collect() } + exception.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY + } + } + + @Test + fun `empty content`() = + goldenStreamingFile("failure-empty-content.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } + + @Test + fun `http errors`() = + goldenStreamingFile("failure-http-error.txt", HttpStatusCode.PreconditionFailed) { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } + + @Test + fun `stopped for safety`() = + goldenStreamingFile("failure-finish-reason-safety.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val exception = shouldThrow { responses.collect() } + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY + } + } + + @Test + fun `citation parsed correctly`() = + goldenStreamingFile("success-citations.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true + } + } + + @Test + fun `stopped for recitation`() = + goldenStreamingFile("failure-recitation-no-content.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val exception = shouldThrow { responses.collect() } + exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION + } + } + + @Test + fun `image rejected`() = + goldenStreamingFile("failure-image-rejected.txt", HttpStatusCode.BadRequest) { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } + + @Test + fun `unknown model`() = + goldenStreamingFile("failure-unknown-model.txt", HttpStatusCode.NotFound) { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } + + @Test + fun `invalid api key`() = + goldenStreamingFile("failure-api-key.txt", HttpStatusCode.BadRequest) { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { shouldThrow { responses.collect() } } + } +} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt new file mode 100644 index 00000000000..f55f71ee34e --- /dev/null +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -0,0 +1,295 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai + +import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.FunctionCallPart +import com.google.firebase.vertexai.type.HarmCategory +import com.google.firebase.vertexai.type.HarmProbability +import com.google.firebase.vertexai.type.HarmSeverity +import com.google.firebase.vertexai.type.InvalidAPIKeyException +import com.google.firebase.vertexai.type.PromptBlockedException +import com.google.firebase.vertexai.type.ResponseStoppedException +import com.google.firebase.vertexai.type.SerializationException +import com.google.firebase.vertexai.type.ServerException +import com.google.firebase.vertexai.type.ServiceDisabledException +import com.google.firebase.vertexai.type.TextPart +import com.google.firebase.vertexai.type.UnsupportedUserLocationException +import com.google.firebase.vertexai.util.goldenUnaryFile +import com.google.firebase.vertexai.util.shouldNotBeNullOrEmpty +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.string.shouldNotBeEmpty +import io.kotest.matchers.types.shouldBeInstanceOf +import io.ktor.http.HttpStatusCode +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.withTimeout +import org.json.JSONArray +import org.junit.Test + +internal class UnarySnapshotTests { + private val testTimeout = 5.seconds + + @Test + fun `short reply`() = + goldenUnaryFile("success-basic-reply-short.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content.parts.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false + } + } + + @Test + fun `long reply`() = + goldenUnaryFile("success-basic-reply-long.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content.parts.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false + } + } + + @Test + fun `unknown enum`() = + goldenUnaryFile("success-unknown-enum.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + val candidate = response.candidates.first() + candidate.safetyRatings.any { it.category == HarmCategory.UNKNOWN } shouldBe true + } + } + + @Test + fun `safetyRatings including severity`() = + goldenUnaryFile("success-including-severity.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false + response.candidates.first().safetyRatings.all { + it.probability == HarmProbability.NEGLIGIBLE + } shouldBe true + response.candidates.first().safetyRatings.all { + it.severity == HarmSeverity.NEGLIGIBLE + } shouldBe true + response.candidates.first().safetyRatings.all { it.severityScore != null } shouldBe true + } + } + + @Test + fun `prompt blocked for safety`() = + goldenUnaryFile("failure-prompt-blocked-safety.json") { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } should + { + it.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY + } + } + } + + @Test + fun `empty content`() = + goldenUnaryFile("failure-empty-content.json") { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + + @Test + fun `http error`() = + goldenUnaryFile("failure-http-error.json", HttpStatusCode.PreconditionFailed) { + withTimeout(testTimeout) { shouldThrow { model.generateContent("prompt") } } + } + + @Test + fun `user location error`() = + goldenUnaryFile("failure-unsupported-user-location.json", HttpStatusCode.PreconditionFailed) { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + + @Test + fun `stopped for safety`() = + goldenUnaryFile("failure-finish-reason-safety.json") { + withTimeout(testTimeout) { + val exception = shouldThrow { model.generateContent("prompt") } + exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY + } + } + + @Test + fun `citation returns correctly`() = + goldenUnaryFile("success-citations.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata.size shouldBe 3 + } + } + + @Test + fun `citation returns correctly with missing license and startIndex`() = + goldenUnaryFile("success-citations-nolicense.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().citationMetadata.isEmpty() shouldBe false + // Verify the values in the citation source + with(response.candidates.first().citationMetadata.first()) { + license shouldBe null + startIndex shouldBe 0 + } + } + } + + @Test + fun `response includes usage metadata`() = + goldenUnaryFile("success-usage-metadata.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.usageMetadata shouldNotBe null + response.usageMetadata?.totalTokenCount shouldBe 363 + } + } + + @Test + fun `response includes partial usage metadata`() = + goldenUnaryFile("success-partial-usage-metadata.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.usageMetadata shouldNotBe null + response.usageMetadata?.promptTokenCount shouldBe 6 + response.usageMetadata?.totalTokenCount shouldBe 0 + } + } + + @Test + fun `properly translates json text`() = + goldenUnaryFile("success-constraint-decoding-json.json") { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + with(response.candidates.first().content.parts.first().shouldBeInstanceOf()) { + shouldNotBeNull() + val jsonArr = JSONArray(text) + jsonArr.length() shouldBe 3 + for (i in 0 until jsonArr.length()) { + with(jsonArr.getJSONObject(i)) { + shouldNotBeNull() + getString("name").shouldNotBeEmpty() + getJSONArray("colors").length() shouldBe 5 + } + } + } + } + + @Test + fun `invalid response`() = + goldenUnaryFile("failure-invalid-response.json") { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + + @Test + fun `malformed content`() = + goldenUnaryFile("failure-malformed-content.json") { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + + @Test + fun `invalid api key`() = + goldenUnaryFile("failure-api-key.json", HttpStatusCode.BadRequest) { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + + @Test + fun `image rejected`() = + goldenUnaryFile("failure-image-rejected.json", HttpStatusCode.BadRequest) { + withTimeout(testTimeout) { shouldThrow { model.generateContent("prompt") } } + } + + @Test + fun `unknown model`() = + goldenUnaryFile("failure-unknown-model.json", HttpStatusCode.NotFound) { + withTimeout(testTimeout) { shouldThrow { model.generateContent("prompt") } } + } + + @Test + fun `service disabled`() = + goldenUnaryFile("failure-service-disabled.json", HttpStatusCode.Forbidden) { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + + @Test + fun `function call contains null param`() = + goldenUnaryFile("success-function-call-null.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart) + + callPart.args["season"] shouldBe null + } + } + + @Test + fun `function call contains json literal`() = + goldenUnaryFile("success-function-call-json-literal.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val content = response.candidates.shouldNotBeNullOrEmpty().first().content + val callPart = + content.let { + it.shouldNotBeNull() + it.parts.shouldNotBeEmpty() + it.parts.first().shouldBeInstanceOf() + } + + callPart.args["current"] shouldBe "true" + } + } +} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/kotlin.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/kotlin.kt new file mode 100644 index 00000000000..53d9de032ae --- /dev/null +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/kotlin.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.util + +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.runBlocking + +/** + * Runs the given [block] using [runBlocking] on the current thread for side effect. + * + * Using this function is like [runBlocking] with default context (which runs the given block on the + * calling thread) but forces the return type to be `Unit`, which is helpful when implementing + * suspending tests as expression functions: + * ``` + * @Test + * fun myTest() = doBlocking {...} + * ``` + */ +internal fun doBlocking(block: suspend CoroutineScope.() -> Unit) { + runBlocking(block = block) +} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt new file mode 100644 index 00000000000..80f7e7a3756 --- /dev/null +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt @@ -0,0 +1,182 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.util + +import com.google.ai.client.generativeai.common.APIController +import com.google.ai.client.generativeai.common.RequestOptions +import com.google.firebase.vertexai.GenerativeModel +import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldNotBeNull +import io.ktor.http.HttpStatusCode +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.close +import io.ktor.utils.io.writeFully +import java.io.File +import kotlinx.coroutines.launch + +private val TEST_CLIENT_ID = "firebase-vertexai-android/test" + +/** String separator used in SSE communication to signal the end of a message. */ +internal const val SSE_SEPARATOR = "\r\n\r\n" + +/** + * Writes the provided [bytes] to the channel and closes it. + * + * Just a wrapper around [writeFully] that closes the channel after writing is complete. + * + * @param bytes the data to send through the channel + */ +internal suspend fun ByteChannel.send(bytes: ByteArray) { + writeFully(bytes) + close() +} + +/** + * Wrapper around common instances needed in tests. + * + * @param channel A [ByteChannel] for sending responses through the mock HTTP engine + * @param apiController A [APIController] that consumes the [channel] + * @see commonTest + * @see send + */ +internal data class CommonTestScope(val channel: ByteChannel, val model: GenerativeModel) + +/** A test that runs under a [CommonTestScope]. */ +internal typealias CommonTest = suspend CommonTestScope.() -> Unit + +/** + * Common test block for providing a [CommonTestScope] during tests. + * + * Example usage: + * ``` + * @Test + * fun `(generateContent) generates a proper response`() = commonTest { + * val request = createRequest("say something nice") + * val response = createResponse("The world is a beautiful place!") + * + * channel.send(prepareResponse(response)) + * + * withTimeout(testTimeout) { + * val data = controller.generateContent(request) + * data.candidates.shouldNotBeEmpty() + * } + * } + * ``` + * + * @param status An optional [HttpStatusCode] to return as a response + * @param requestOptions Optional [RequestOptions] to utilize in the underlying controller + * @param block The test contents themselves, with the [CommonTestScope] implicitly provided + * @see CommonTestScope + */ +internal fun commonTest( + status: HttpStatusCode = HttpStatusCode.OK, + requestOptions: RequestOptions = RequestOptions(), + block: CommonTest, +) = doBlocking { + val channel = ByteChannel(autoFlush = true) + val apiController = + APIController( + "super_cool_test_key", + "gemini-pro", + requestOptions, + TEST_CLIENT_ID, + null, + channel, + status, + ) + val model = GenerativeModel("cool-model-name", controller = apiController) + CommonTestScope(channel, model).block() +} + +/** + * A variant of [commonTest] for performing *streaming-based* snapshot tests. + * + * Loads the *Golden File* and automatically parses the messages from it; providing it to the + * channel. + * + * @param name The name of the *Golden File* to load + * @param httpStatusCode An optional [HttpStatusCode] to return as a response + * @param block The test contents themselves, with a [CommonTestScope] implicitly provided + * @see goldenUnaryFile + */ +internal fun goldenStreamingFile( + name: String, + httpStatusCode: HttpStatusCode = HttpStatusCode.OK, + block: CommonTest, +) = doBlocking { + val goldenFile = loadGoldenFile("streaming-$name") + val messages = goldenFile.readLines().filter { it.isNotBlank() } + + commonTest(httpStatusCode) { + launch { + for (message in messages) { + channel.writeFully("$message$SSE_SEPARATOR".toByteArray()) + } + channel.close() + } + + block() + } +} + +/** + * A variant of [commonTest] for performing snapshot tests. + * + * Loads the *Golden File* and automatically provides it to the channel. + * + * @param name The name of the *Golden File* to load + * @param httpStatusCode An optional [HttpStatusCode] to return as a response + * @param block The test contents themselves, with a [CommonTestScope] implicitly provided + * @see goldenStreamingFile + */ +internal fun goldenUnaryFile( + name: String, + httpStatusCode: HttpStatusCode = HttpStatusCode.OK, + block: CommonTest, +) = + commonTest(httpStatusCode) { + val goldenFile = loadGoldenFile("unary-$name") + val message = goldenFile.readText() + + channel.send(message.toByteArray()) + + block() + } + +/** + * Loads a *Golden File* from the resource directory. + * + * Expects golden files to live under `golden-files` in the resource files. + * + * @see goldenUnaryFile + */ +internal fun loadGoldenFile(path: String): File = + loadResourceFile("vertexai-sdk-test-data/mock-responses/$path") + +/** Loads a file from the test resources directory. */ +internal fun loadResourceFile(path: String) = File("src/test/resources/$path") + +/** + * Ensures that a collection is neither null or empty. + * + * Syntax sugar for [shouldNotBeNull] and [shouldNotBeEmpty]. + */ +inline fun Collection?.shouldNotBeNullOrEmpty(): Collection { + shouldNotBeNull() + shouldNotBeEmpty() + return this +} diff --git a/firebase-vertexai/src/test/resources/README.md b/firebase-vertexai/src/test/resources/README.md new file mode 100644 index 00000000000..cba0fad1813 --- /dev/null +++ b/firebase-vertexai/src/test/resources/README.md @@ -0,0 +1,2 @@ +Mock response files should be cloned into this directory to run unit tests. See +the Vertex AI for Firebase [README](../../..#running-tests) for instructions. \ No newline at end of file diff --git a/firebase-vertexai/update_responses.sh b/firebase-vertexai/update_responses.sh new file mode 100755 index 00000000000..7f8fa38a73e --- /dev/null +++ b/firebase-vertexai/update_responses.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script replaces mock response files for Vertex AI unit tests with a fresh +# clone of the shared repository of Vertex AI test data. + +cd "$(dirname "$0")/src/test/resources" || exit +rm -rf vertexai-sdk-test-data +git clone --depth 1 https://github.com/FirebaseExtended/vertexai-sdk-test-data.git