Skip to content

Commit

Permalink
Add Vertex AI unit tests (#6090)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanzimfh committed Jul 15, 2024
1 parent d26a5f8 commit 4db5dc8
Show file tree
Hide file tree
Showing 11 changed files with 747 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ jobs:
./gradlew :common:updateVersion common:publishToMavenLocal
cd ..
- name: Clone mock responses
if: matrix.module == ':firebase-vertexai'
run: |
firebase-vertexai/update_responses.sh
- name: Add google-services.json
env:
INTEG_TESTS_GOOGLE_SERVICES: ${{ secrets.INTEG_TESTS_GOOGLE_SERVICES }}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ smoke-test-logs/
smoke-tests/build-debug-headGit-smoke-test
smoke-tests/firehorn.log
macrobenchmark-output.json
vertexai-sdk-test-data/

# generated Terraform docs
.terraform/*
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ Unit tests can be executed on the command line by running
./gradlew :<firebase-project>:check
```

#### Vertex AI for Firebase

See the Vertex AI for Firebase [README](firebase-vertexai#running-tests) for setup
instructions specific to that project.

### Integration Testing

These are tests that run on a hardware device or emulator. These tests have
Expand Down
4 changes: 4 additions & 0 deletions firebase-vertexai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ All Gradle commands should be run from the root of this repository.

## Running Tests

> [!IMPORTANT]
> These unit tests require mock response files, which can be downloaded by running
`./firebase-vertexai/update_responses.sh` from the root of this repository.

Unit tests:

`./gradlew :firebase-vertexai:check`
Expand Down
11 changes: 9 additions & 2 deletions firebase-vertexai/firebase-vertexai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ android {
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions { jvmTarget = "1.8" }
testOptions.unitTests.isIncludeAndroidResources = true
testOptions {
unitTests.isIncludeAndroidResources = true
unitTests.isReturnDefaultValues = true
}
}

dependencies {
Expand All @@ -58,7 +61,7 @@ dependencies {
implementation("com.google.firebase:firebase-components:18.0.0")
implementation("com.google.firebase:firebase-annotations:16.2.0")
implementation("com.google.firebase:firebase-appcheck-interop:17.1.0")
implementation("com.google.ai.client.generativeai:common:0.7.1")
implementation("com.google.ai.client.generativeai:common:0.9.0")
implementation(libs.androidx.annotation)
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("androidx.core:core-ktx:1.12.0")
Expand All @@ -71,8 +74,12 @@ dependencies {
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03")
implementation("com.google.firebase:firebase-auth-interop:18.0.0")

val ktorVersion = "2.3.2"
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
testImplementation("io.ktor:ktor-client-okhttp:$ktorVersion")
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
testImplementation("org.json:json:20240303")
testImplementation(libs.androidx.test.junit)
testImplementation(libs.androidx.test.runner)
testImplementation(libs.junit)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai

import com.google.firebase.vertexai.type.BlockReason
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.HarmCategory
import com.google.firebase.vertexai.type.InvalidAPIKeyException
import com.google.firebase.vertexai.type.PromptBlockedException
import com.google.firebase.vertexai.type.ResponseStoppedException
import com.google.firebase.vertexai.type.SerializationException
import com.google.firebase.vertexai.type.ServerException
import com.google.firebase.vertexai.type.TextPart
import com.google.firebase.vertexai.util.goldenStreamingFile
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.http.HttpStatusCode
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withTimeout
import org.junit.Test

internal class StreamingSnapshotTests {
private val testTimeout = 5.seconds

@Test
fun `short reply`() =
goldenStreamingFile("success-basic-reply-short.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.isEmpty() shouldBe false
responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP
responseList.first().candidates.first().content.parts.isEmpty() shouldBe false
responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false
}
}

@Test
fun `long reply`() =
goldenStreamingFile("success-basic-reply-long.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.isEmpty() shouldBe false
responseList.forEach {
it.candidates.first().finishReason shouldBe FinishReason.STOP
it.candidates.first().content.parts.isEmpty() shouldBe false
it.candidates.first().safetyRatings.isEmpty() shouldBe false
}
}
}

@Test
fun `unknown enum`() =
goldenStreamingFile("success-unknown-enum.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()

responseList.isEmpty() shouldBe false
responseList.any {
it.candidates.any { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } }
} shouldBe true
}
}

@Test
fun `quotes escaped`() =
goldenStreamingFile("success-quotes-escaped.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()

responseList.isEmpty() shouldBe false
val part = responseList.first().candidates.first().content.parts.first() as? TextPart
part.shouldNotBeNull()
part.text shouldContain "\""
}
}

@Test
fun `prompt blocked for safety`() =
goldenStreamingFile("failure-prompt-blocked-safety.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val exception = shouldThrow<PromptBlockedException> { responses.collect() }
exception.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY
}
}

@Test
fun `empty content`() =
goldenStreamingFile("failure-empty-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
}

@Test
fun `http errors`() =
goldenStreamingFile("failure-http-error.txt", HttpStatusCode.PreconditionFailed) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
}

@Test
fun `stopped for safety`() =
goldenStreamingFile("failure-finish-reason-safety.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY
}
}

@Test
fun `citation parsed correctly`() =
goldenStreamingFile("success-citations.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true
}
}

@Test
fun `stopped for recitation`() =
goldenStreamingFile("failure-recitation-no-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION
}
}

@Test
fun `image rejected`() =
goldenStreamingFile("failure-image-rejected.txt", HttpStatusCode.BadRequest) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
}

@Test
fun `unknown model`() =
goldenStreamingFile("failure-unknown-model.txt", HttpStatusCode.NotFound) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
}

@Test
fun `invalid api key`() =
goldenStreamingFile("failure-api-key.txt", HttpStatusCode.BadRequest) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { responses.collect() } }
}
}
Loading

0 comments on commit 4db5dc8

Please sign in to comment.