Skip to content

Commit

Permalink
Refactor ai parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Nov 3, 2024
1 parent ca6230e commit 658bb29
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ INPUT_PROMPT
actualImageFilePath: String,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults
companion object {
const val DefaultMaxOutputTokens = 300
const val DefaultTemperature = 0.4F
}
}

data class AiAssertion(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.github.takahirom.roborazzi

import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import dev.shreyaspatil.ai.client.generativeai.GenerativeModel
import dev.shreyaspatil.ai.client.generativeai.type.FunctionType
import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig
Expand All @@ -15,9 +17,10 @@ import kotlinx.serialization.Serializable
@ExperimentalRoborazziApi
class GeminiAiAssertionModel(
private val apiKey: String,
private val modelName: String = "gemini-1.5-pro",
private val modelName: String = "gemini-1.5-flash",
private val generationConfigBuilder: GenerationConfig.Builder.() -> Unit = {
maxOutputTokens = 8192
maxOutputTokens = DefaultMaxOutputTokens
temperature = DefaultTemperature
}
) : AiAssertionModel {
override fun assert(
Expand Down
2 changes: 1 addition & 1 deletion roborazzi-ai-openai/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ kotlin {
implementation "io.github.takahirom.roborazzi:roborazzi-core:$VERSION_NAME"
implementation(libs.kotlinx.coroutines.core)
implementation(libs.ktor.serialization.json)
api(libs.ktor.client.core)
implementation(libs.ktor.client.core)
implementation(libs.ktor.client.cio)
implementation(libs.ktor.client.logging)
implementation(libs.ktor.client.contentnegotiation)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.github.takahirom.roborazzi

import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import com.github.takahirom.roborazzi.CaptureResults.Companion.json
import io.ktor.client.HttpClient
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
Expand Down Expand Up @@ -36,8 +38,8 @@ class OpenAiAiAssertionModel(
private val modelName: String = "gpt-4o",
private val baseUrl: String = "https://api.openai.com/v1/",
private val loggingEnabled: Boolean = false,
private val temperature: Float = 0.4F,
private val maxTokens: Int = 300,
private val temperature: Float = DefaultTemperature,
private val maxTokens: Int = DefaultMaxOutputTokens,
private val seed: Int = 1566,
private val requestBuilderModifier: (HttpRequestBuilder.() -> Unit) = {
header("Authorization", "Bearer $apiKey")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@ import androidx.test.espresso.Espresso.onView
import androidx.test.espresso.matcher.ViewMatchers
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.github.takahirom.roborazzi.AiAssertionOptions
import com.github.takahirom.roborazzi.DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH
import com.github.takahirom.roborazzi.GeminiAiAssertionModel
import com.github.takahirom.roborazzi.ROBORAZZI_DEBUG
import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers
import com.github.takahirom.roborazzi.RoborazziOptions
import com.github.takahirom.roborazzi.RoborazziRule
import com.github.takahirom.roborazzi.RoborazziTaskType
import com.github.takahirom.roborazzi.captureRoboImage
import com.github.takahirom.roborazzi.provideRoborazziContext
import com.github.takahirom.roborazzi.roboOutputName
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.annotation.Config
import org.robolectric.annotation.GraphicsMode
import java.io.File

@RunWith(AndroidJUnit4::class)
@GraphicsMode(GraphicsMode.Mode.NATIVE)
Expand All @@ -32,6 +36,7 @@ class GeminiAiAiTest {
val roborazziRule = RoborazziRule(
options = RoborazziRule.Options(
roborazziOptions = RoborazziOptions(
taskType = RoborazziTaskType.Compare,
compareOptions = RoborazziOptions.CompareOptions(
aiAssertionOptions = AiAssertionOptions(
aiAssertionModel = GeminiAiAssertionModel(
Expand All @@ -50,6 +55,7 @@ class GeminiAiAiTest {
println("Skip the test because gemini_api_key is not set.")
return
}
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + ".png").delete()
onView(ViewMatchers.isRoot())
.captureRoboImage(
roborazziOptions = provideRoborazziContext().options.addedAiAssertions(
Expand All @@ -63,5 +69,6 @@ class GeminiAiAiTest {
)
)
)
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + "_compare.png").delete()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@ import androidx.test.espresso.Espresso.onView
import androidx.test.espresso.matcher.ViewMatchers
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.github.takahirom.roborazzi.AiAssertionOptions
import com.github.takahirom.roborazzi.DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH
import com.github.takahirom.roborazzi.OpenAiAiAssertionModel
import com.github.takahirom.roborazzi.ROBORAZZI_DEBUG
import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers
import com.github.takahirom.roborazzi.RoborazziOptions
import com.github.takahirom.roborazzi.RoborazziRule
import com.github.takahirom.roborazzi.RoborazziTaskType
import com.github.takahirom.roborazzi.captureRoboImage
import com.github.takahirom.roborazzi.provideRoborazziContext
import com.github.takahirom.roborazzi.roboOutputName
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.annotation.Config
import org.robolectric.annotation.GraphicsMode
import java.io.File

@RunWith(AndroidJUnit4::class)
@GraphicsMode(GraphicsMode.Mode.NATIVE)
Expand All @@ -32,6 +36,7 @@ class OpenAiTest {
val roborazziRule = RoborazziRule(
options = RoborazziRule.Options(
roborazziOptions = RoborazziOptions(
taskType = RoborazziTaskType.Compare,
compareOptions = RoborazziOptions.CompareOptions(
aiAssertionOptions = AiAssertionOptions(
aiAssertionModel = OpenAiAiAssertionModel(
Expand All @@ -45,12 +50,13 @@ class OpenAiTest {
)

@Test
fun captureWithAi3() {
fun captureWithAi() {
ROBORAZZI_DEBUG = true
if (System.getenv("openai_api_key") == null) {
println("Skip the test because openai_api_key is not set.")
return
}
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + ".png").delete()
onView(ViewMatchers.isRoot())
.captureRoboImage(
roborazziOptions = provideRoborazziContext().options.addedAiAssertions(
Expand All @@ -64,5 +70,6 @@ class OpenAiTest {
)
)
)
File(DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH + File.separator + roboOutputName() + "_compare.png").delete()
}
}

0 comments on commit 658bb29

Please sign in to comment.