Skip to content

Commit

Permalink
feat: LVM - Added multi-language support for ImageGenerationModel
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 580285009
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 7, 2023
1 parent 9c4decc commit 791eff5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/system/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_image_generation_model_generate_images(self):
number_of_images = 4
seed = 1
guidance_scale = 15
language = "en"

prompt1 = "Astronaut riding a horse"
negative_prompt1 = "bad quality"
Expand All @@ -110,6 +111,7 @@ def test_image_generation_model_generate_images(self):
# height=height,
seed=seed,
guidance_scale=guidance_scale,
language=language,
)

assert len(image_response.images) == number_of_images
Expand All @@ -125,6 +127,7 @@ def test_image_generation_model_generate_images(self):
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["language"] == language

# Test saving and loading images
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -134,6 +137,7 @@ def test_image_generation_model_generate_images(self):
# assert image1._pil_image.size == (width, height)
assert image1.generation_parameters
assert image1.generation_parameters["prompt"] == prompt1
assert image1.generation_parameters["language"] == language

# Preparing mask
mask_path = os.path.join(temp_dir, "mask.png")
Expand All @@ -151,6 +155,7 @@ def test_image_generation_model_generate_images(self):
guidance_scale=guidance_scale,
base_image=image1,
mask=mask_image,
language=language,
)
assert len(image_response2.images) == number_of_images
for idx, image in enumerate(image_response2):
Expand All @@ -161,5 +166,6 @@ def test_image_generation_model_generate_images(self):
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert image.generation_parameters["language"] == language
assert "base_image_hash" in image.generation_parameters
assert "mask_hash" in image.generation_parameters
9 changes: 9 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def test_generate_images(self):
number_of_images = 4
seed = 1
guidance_scale = 15
language = "en"

image_generation_response = make_image_generation_response(
width=width, height=height, count=number_of_images
Expand All @@ -246,6 +247,7 @@ def test_generate_images(self):
# height=height,
seed=seed,
guidance_scale=guidance_scale,
language=language,
)
predict_kwargs = mock_predict.call_args[1]
actual_parameters = predict_kwargs["parameters"]
Expand All @@ -257,6 +259,7 @@ def test_generate_images(self):
# assert actual_parameters["aspectRatio"] == f"{width}:{height}"
assert actual_parameters["seed"] == seed
assert actual_parameters["guidanceScale"] == guidance_scale
assert actual_parameters["language"] == language

assert len(image_response.images) == number_of_images
for idx, image in enumerate(image_response):
Expand All @@ -269,6 +272,7 @@ def test_generate_images(self):
# assert image.generation_parameters["height"] == height
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["language"] == language
assert image.generation_parameters["index_of_image_in_batch"] == idx
image.show()

Expand All @@ -280,6 +284,7 @@ def test_generate_images(self):
# assert image1._pil_image.size == (width, height)
assert image1.generation_parameters
assert image1.generation_parameters["prompt"] == prompt1
assert image1.generation_parameters["language"] == language

# Preparing mask
mask_path = os.path.join(temp_dir, "mask.png")
Expand All @@ -302,12 +307,15 @@ def test_generate_images(self):
guidance_scale=guidance_scale,
base_image=image1,
mask=mask_image,
language=language,
)
predict_kwargs = mock_predict.call_args[1]
actual_parameters = predict_kwargs["parameters"]
actual_instance = predict_kwargs["instances"][0]
assert actual_instance["prompt"] == prompt2
assert actual_instance["image"]["bytesBase64Encoded"]
assert actual_instance["mask"]["image"]["bytesBase64Encoded"]
assert actual_parameters["language"] == language

assert len(image_response2.images) == number_of_images
for image in image_response2:
Expand All @@ -316,6 +324,7 @@ def test_generate_images(self):
assert image.generation_parameters["prompt"] == prompt2
assert image.generation_parameters["base_image_hash"]
assert image.generation_parameters["mask_hash"]
assert image.generation_parameters["language"] == language

@unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
def test_generate_images_requests_square_images_by_default(self):
Expand Down
18 changes: 18 additions & 0 deletions vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def _generate_images(
seed: Optional[int] = None,
base_image: Optional["Image"] = None,
mask: Optional["Image"] = None,
language:Optional[str] = None,
) -> "ImageGenerationResponse":
"""Generates images from text prompt.
Expand All @@ -160,6 +161,9 @@ def _generate_images(
seed: Image generation random seed.
base_image: Base image to use for the image generation.
mask: Mask for the base image.
language: Language of the text prompt for the image. Default: None.
Supported values are `"en"` for English, `"hi"` for Hindi,
`"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
Returns:
An `ImageGenerationResponse` object.
Expand Down Expand Up @@ -216,6 +220,10 @@ def _generate_images(
parameters["guidanceScale"] = guidance_scale
shared_generation_parameters["guidance_scale"] = guidance_scale

if language is not None:
parameters["language"] = language
shared_generation_parameters["language"] = language

response = self._endpoint.predict(
instances=[instance],
parameters=parameters,
Expand All @@ -241,6 +249,7 @@ def generate_images(
negative_prompt: Optional[str] = None,
number_of_images: int = 1,
guidance_scale: Optional[float] = None,
language: Optional[str] = None,
seed: Optional[int] = None,
) -> "ImageGenerationResponse":
"""Generates images from text prompt.
Expand All @@ -255,6 +264,9 @@ def generate_images(
* 0-9 (low strength)
* 10-20 (medium strength)
* 21+ (high strength)
language: Language of the text prompt for the image. Default: None.
Supported values are `"en"` for English, `"hi"` for Hindi,
`"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
seed: Image generation random seed.
Returns:
Expand All @@ -268,6 +280,7 @@ def generate_images(
width=None,
height=None,
guidance_scale=guidance_scale,
language=language,
seed=seed,
)

Expand All @@ -280,6 +293,7 @@ def edit_image(
negative_prompt: Optional[str] = None,
number_of_images: int = 1,
guidance_scale: Optional[float] = None,
language: Optional[str] = None,
seed: Optional[int] = None,
) -> "ImageGenerationResponse":
"""Edits an existing image based on text prompt.
Expand All @@ -296,6 +310,9 @@ def edit_image(
* 0-9 (low strength)
* 10-20 (medium strength)
* 21+ (high strength)
language: Language of the text prompt for the image. Default: None.
Supported values are `"en"` for English, `"hi"` for Hindi,
`"ja"` for Japanese, `"ko"` for Korean, and `"auto"` for automatic language detection.
seed: Image generation random seed.
Returns:
Expand All @@ -309,6 +326,7 @@ def edit_image(
seed=seed,
base_image=base_image,
mask=mask,
language=language,
)

def upscale_image(
Expand Down

0 comments on commit 791eff5

Please sign in to comment.