Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Vertex AI snippets to match the iOS ones #553

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class ChatViewModel extends ViewModel {
private GenerativeModelFutures model;

void startChatSendMessageStream() {
// [START vertexai_send_message_stream]
// [START chat_streaming]
// (optional) Create previous chat history for context
Content.Builder userContentBuilder = new Content.Builder();
userContentBuilder.setRole("user");
Expand Down Expand Up @@ -84,11 +84,11 @@ public void onError(Throwable t) {
}
// [END_EXCLUDE]
});
// [END vertexai_send_message_stream]
// [END chat_streaming]
}

void startChatSendMessage(Executor executor) {
// [START vertexai_send_message]
// [START chat]
// (optional) Create previous chat history for context
Content.Builder userContentBuilder = new Content.Builder();
userContentBuilder.setRole("user");
Expand Down Expand Up @@ -126,12 +126,12 @@ public void onFailure(@NonNull Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_send_message]
// [END chat]
}

void countTokensChat(Executor executor) {
ChatFutures chat = model.startChat();
// [START vertexai_count_tokens_chat]
// [START count_tokens_chat]
List<Content> history = chat.getChat().getHistory();

Content messageContent = new Content.Builder()
Expand All @@ -156,24 +156,26 @@ public void onFailure(@NonNull Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_count_tokens_chat]
// [END count_tokens_chat]
}

void systemInstructionsText() {
// [START vertexai_si_text]
// [START system_instructions_text]
// Initialize the Vertex AI service and the generative model
// Specify a model that supports system instructions, like a Gemini 1.5 model
Content systemInstruction = new Content.Builder()
.addText("You are a cat. Your name is Neko.")
.build();
GenerativeModel model = FirebaseVertexAI.getInstance()
.generativeModel(
/* modelName */ "gemini-1.5-pro-preview-0409",
/* modelName */ "gemini-1.5-flash",
/* generationConfig (optional) */ null,
/* safetySettings (optional) */ null,
/* requestOptions (optional) */ new RequestOptions(),
/* tools (optional) */ null,
/* toolsConfig (optional) */ null,
/* systemInstruction (optional) */ systemInstruction
);
// [END vertexai_si_text]
// [END system_instructions_text]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
public class ConfigurationViewModel extends ViewModel {

void configModelParams() {
// [START vertexai_model_params]
// [START configure_model]
GenerationConfig.Builder configBuilder = new GenerationConfig.Builder();
configBuilder.temperature = 0.9f;
configBuilder.topK = 16;
Expand All @@ -29,40 +29,44 @@ void configModelParams() {
GenerationConfig generationConfig = configBuilder.build();

GenerativeModel gm = FirebaseVertexAI.Companion.getInstance().generativeModel(
"MODEL_NAME",
"gemini-1.5-flash",
generationConfig
);

GenerativeModelFutures model = GenerativeModelFutures.from(gm);
// [END vertexai_model_params]
// [END configure_model]
}

void configSafetySettings() {
SafetySetting harassmentSafety1 = new SafetySetting(HarmCategory.HARASSMENT,
// [START safety_settings]
SafetySetting harassmentSafety = new SafetySetting(HarmCategory.HARASSMENT,
BlockThreshold.ONLY_HIGH);

GenerativeModel gm1 = FirebaseVertexAI.Companion.getInstance().generativeModel(
"MODEL_NAME",
GenerativeModel gm = FirebaseVertexAI.Companion.getInstance().generativeModel(
"gemini-1.5-flash",
/* generationConfig is optional */ null,
Collections.singletonList(harassmentSafety1)
Collections.singletonList(harassmentSafety)
);

GenerativeModelFutures model1 = GenerativeModelFutures.from(gm1);
GenerativeModelFutures model = GenerativeModelFutures.from(gm);
// [END safety_settings]
}

// [START vertexai_safety_settings]
void configMultiSafetySettings() {
// [START multi_safety_settings]
SafetySetting harassmentSafety = new SafetySetting(HarmCategory.HARASSMENT,
BlockThreshold.ONLY_HIGH);

SafetySetting hateSpeechSafety = new SafetySetting(HarmCategory.HATE_SPEECH,
BlockThreshold.MEDIUM_AND_ABOVE);

GenerativeModel gm = FirebaseVertexAI.Companion.getInstance().generativeModel(
"MODEL_NAME",
"gemini-1.5-flash",
/* generationConfig is optional */ null,
List.of(harassmentSafety, hateSpeechSafety)
);

GenerativeModelFutures model = GenerativeModelFutures.from(gm);
// [END vertexai_safety_settings]
// [END multi_safety_settings]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ public class GenerateContentViewModel extends ViewModel {
// Only meant to separate the scope of the initialization snippet
// so that it doesn't cause a naming clash with the top level declaration
static class InitializationSnippet {
// [START vertexai_init]
// [START initialize_model]
GenerativeModel gm = FirebaseVertexAI.getInstance()
.generativeModel("gemini-1.5-pro-preview-0409");
.generativeModel("gemini-1.5-flash");

GenerativeModelFutures model = GenerativeModelFutures.from(gm);
// [END vertexai_init]
// [END initialize_model]
}

void generateContentStream() {
// [START vertexai_textonly_stream]
// [START text_gen_text_only_prompt_streaming]
Content prompt = new Content.Builder()
.addText("Write a story about a magic backpack.")
.build();
Expand Down Expand Up @@ -77,11 +77,11 @@ public void onError(Throwable t) {
public void onSubscribe(Subscription s) {
}
});
// [END vertexai_textonly_stream]
// [END text_gen_text_only_prompt_streaming]
}

void generateContent(Executor executor) {
// [START vertexai_textonly]
// [START text_gen_text_only_prompt]
// Provide a prompt that contains text
Content prompt = new Content.Builder()
.addText("Write a story about a magic backpack.")
Expand All @@ -101,7 +101,7 @@ public void onFailure(Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_textonly]
// [END text_gen_text_only_prompt]
}

// Fake implementation to exemplify Activity.getResources()
Expand All @@ -115,7 +115,7 @@ Context getApplicationContext() {
}

void generateContentWithImageStream() {
// [START vertexai_text_and_image_stream]
// [START text_gen_multimodal_one_image_prompt_streaming]
Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.sparky);

Content prompt = new Content.Builder()
Expand Down Expand Up @@ -149,11 +149,11 @@ public void onError(Throwable t) {
public void onSubscribe(Subscription s) {
}
});
// [END vertexai_text_and_image_stream]
// [END text_gen_multimodal_one_image_prompt_streaming]
}

void generateContentWithImage(Executor executor) {
// [START vertexai_text_and_image]
// [START text_gen_multimodal_one_image_prompt]
Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.sparky);

Content content = new Content.Builder()
Expand All @@ -174,11 +174,11 @@ public void onFailure(Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_text_and_image]
// [END text_gen_multimodal_one_image_prompt]
}

void generateContentWithMultipleImagesStream() {
// [START vertexai_text_and_images_stream]
// [START text_gen_multimodal_multi_image_prompt_streaming]
Bitmap bitmap1 = BitmapFactory.decodeResource(getResources(), R.drawable.sparky);
Bitmap bitmap2 = BitmapFactory.decodeResource(getResources(), R.drawable.sparky_eats_pizza);

Expand Down Expand Up @@ -215,11 +215,11 @@ public void onError(Throwable t) {
public void onSubscribe(Subscription s) {
}
});
// [END vertexai_text_and_images_stream]
// [END text_gen_multimodal_multi_image_prompt_streaming]
}

void generateContentWithMultipleImages(Executor executor) {
// [START vertexai_text_and_images]
// [START text_gen_multimodal_multi_image_prompt]
Bitmap bitmap1 = BitmapFactory.decodeResource(getResources(), R.drawable.sparky);
Bitmap bitmap2 = BitmapFactory.decodeResource(getResources(), R.drawable.sparky_eats_pizza);

Expand All @@ -243,11 +243,11 @@ public void onFailure(Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_text_and_images]
// [END text_gen_multimodal_multi_image_prompt]
}

void generateContentWithVideo(Executor executor, Uri videoUri) {
// [START vertexai_text_and_video]
// [START text_gen_multimodal_video_prompt]
ContentResolver resolver = getApplicationContext().getContentResolver();
try (InputStream stream = resolver.openInputStream(videoUri)) {
File videoFile = new File(new URI(videoUri.toString()));
Expand Down Expand Up @@ -281,13 +281,13 @@ public void onFailure(Throwable t) {
} catch (URISyntaxException e) {
e.printStackTrace();
}
// [END vertexai_text_and_video]
// [END text_gen_multimodal_video_prompt]
}

void generateContentWithVideoStream(
Uri videoUri
) {
// [START vertexai_text_and_video_stream]
// [START text_gen_multimodal_video_prompt_streaming]
ContentResolver resolver = getApplicationContext().getContentResolver();
try (InputStream stream = resolver.openInputStream(videoUri)) {
File videoFile = new File(new URI(videoUri.toString()));
Expand Down Expand Up @@ -334,16 +334,17 @@ public void onSubscribe(Subscription s) {
} catch (URISyntaxException e) {
e.printStackTrace();
}
// [END vertexai_text_and_video_stream]
// [END text_gen_multimodal_video_prompt_streaming]
}

void countTokensText(Executor executor) {
// [START vertexai_count_tokens_text]
Content text = new Content.Builder()
// [START count_tokens_text]
Content prompt = new Content.Builder()
.addText("Write a story about a magic backpack.")
.build();

ListenableFuture<CountTokensResponse> countTokensResponse = model.countTokens(text);
// Count tokens and billable characters before calling generateContent
ListenableFuture<CountTokensResponse> countTokensResponse = model.countTokens(prompt);

Futures.addCallback(countTokensResponse, new FutureCallback<CountTokensResponse>() {
@Override
Expand All @@ -352,25 +353,28 @@ public void onSuccess(CountTokensResponse result) {
int totalBillableTokens = result.getTotalBillableCharacters();
System.out.println("totalTokens = " + totalTokens +
"totalBillableTokens = " + totalBillableTokens);

// To generate text output, call generateContent with the text input
ListenableFuture<GenerateContentResponse> response = model.generateContent(prompt);
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_count_tokens_text]
// [END count_tokens_text]
}

void countTokensMultimodal(Executor executor, Bitmap bitmap) {
// [START vertexai_count_tokens_multimodal]
Content text = new Content.Builder()
// [START count_tokens_text_image]
Content prompt = new Content.Builder()
.addImage(bitmap)
.addText("Where can I buy this")
.build();

// For text-only input
ListenableFuture<CountTokensResponse> countTokensResponse = model.countTokens(text);
// Count tokens and billable characters before calling generateContent
ListenableFuture<CountTokensResponse> countTokensResponse = model.countTokens(prompt);

Futures.addCallback(countTokensResponse, new FutureCallback<CountTokensResponse>() {
@Override
Expand All @@ -379,13 +383,16 @@ public void onSuccess(CountTokensResponse result) {
int totalBillableTokens = result.getTotalBillableCharacters();
System.out.println("totalTokens = " + totalTokens +
"totalBillableTokens = " + totalBillableTokens);

// To generate text output, call generateContent with the prompt
ListenableFuture<GenerateContentResponse> response = model.generateContent(prompt);
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}
}, executor);
// [END vertexai_count_tokens_multimodal]
// [END count_tokens_text_image]
}
}
Loading
Loading