Skip to content

Commit

Permalink
feat: GenAI - Added support for system instructions
Browse files Browse the repository at this point in the history
Usage:
```
model = generative_models.GenerativeModel(
    "gemini-1.0-pro",
    system_instruction=[
        "Talk like a pirate.",
        "Don't use rude words.",
    ],
)
```
PiperOrigin-RevId: 621703355
  • Loading branch information
Ark-kun authored and copybara-github committed Apr 4, 2024
1 parent d0585e8 commit 4990eb6
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
8 changes: 7 additions & 1 deletion tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,13 @@ async def test_generate_content_streaming_async(self):
assert chunk.text

def test_generate_content_with_parameters(self):
model = generative_models.GenerativeModel("gemini-pro")
model = generative_models.GenerativeModel(
"gemini-pro",
system_instruction=[
"Talk like a pirate.",
"Don't use rude words.",
],
)
response = model.generate_content(
contents="Why is sky blue?",
generation_config=generative_models.GenerationConfig(
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,14 @@ def test_generate_content(self, generative_models: generative_models):
response = model.generate_content("Why is sky blue?")
assert response.text

response2 = model.generate_content(
model2 = generative_models.GenerativeModel(
"gemini-pro",
system_instruction=[
"Talk like a pirate.",
"Don't use rude words.",
],
)
response2 = model2.generate_content(
"Why is sky blue?",
generation_config=generative_models.GenerationConfig(
temperature=0.2,
Expand Down
13 changes: 13 additions & 0 deletions vertexai/generative_models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ print(vision_chat.send_message(["I like this image.", image]))
print(vision_chat.send_message("What things do I like?."))
```

#### System instructions
```
from vertexai.generative_models import GenerativeModel
model = GenerativeModel(
"gemini-1.0-pro",
system_instruction=[
"Talk like a pirate.",
"Don't use rude words.",
],
)
print(model.generate_content("Why is sky blue?"))
```

#### Function calling

```
Expand Down
13 changes: 13 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
system_instruction: Optional[PartsType] = None,
):
r"""Initializes GenerativeModel.
Expand All @@ -147,6 +148,9 @@ def __init__(
generation_config: Default generation config to use in generate_content.
safety_settings: Default safety settings to use in generate_content.
tools: Default tools to use in generate_content.
system_instruction: Default system instruction to use in generate_content.
Note: Only text should be used in parts.
Content of each part will become a separate paragraph.
"""
if "/" not in model_name:
model_name = "publishers/google/models/" + model_name
Expand All @@ -163,13 +167,15 @@ def __init__(
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._system_instruction = system_instruction

# Validating the parameters
self._prepare_request(
contents="test",
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
system_instruction=system_instruction,
)

@property
Expand Down Expand Up @@ -205,6 +211,7 @@ def _prepare_request(
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
tools: Optional[List["Tool"]] = None,
system_instruction: Optional[PartsType] = None,
) -> gapic_prediction_service_types.GenerateContentRequest:
"""Prepares a GAPIC GenerateContentRequest."""
if not contents:
Expand All @@ -213,6 +220,7 @@ def _prepare_request(
generation_config = generation_config or self._generation_config
safety_settings = safety_settings or self._safety_settings
tools = tools or self._tools
system_instruction = system_instruction or self._system_instruction

# contents can either be a list of Content objects (most generic case)
if isinstance(contents, Sequence) and any(
Expand Down Expand Up @@ -244,6 +252,10 @@ def _prepare_request(
else:
contents = [_to_content(contents)]

gapic_system_instruction: Optional[gapic_content_types.Content] = None
if system_instruction:
gapic_system_instruction = _to_content(system_instruction)

gapic_generation_config: Optional[gapic_content_types.GenerationConfig] = None
if generation_config:
if isinstance(generation_config, gapic_content_types.GenerationConfig):
Expand Down Expand Up @@ -307,6 +319,7 @@ def _prepare_request(
generation_config=gapic_generation_config,
safety_settings=gapic_safety_settings,
tools=gapic_tools,
system_instruction=gapic_system_instruction,
)

def _parse_response(
Expand Down

0 comments on commit 4990eb6

Please sign in to comment.