mps: Alternative implementation for repeat_interleave (open-mmlab…#766)

* mps: alt. implementation for repeat_interleave

* style

* Bump mps version of PyTorch in the documentation.

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Simplify: do not check for device.

* style

* Fix repeat dimensions:

- The unconditional embeddings are always created from a single prompt.
- I was shadowing the batch_size var.

* Split long lines as suggested by Suraj.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
3 people authored Oct 11, 2022
1 parent 757babf · commit 24b8b5c
Showing 2 changed files with 9 additions and 5 deletions.
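
The core of the change swaps `repeat_interleave` for a `repeat` followed by `view`, described in the new code comments as an "mps friendly method". Below is a minimal sketch, not taken from the commit, of why the two produce identical results when repeating along the batch dimension; the tensor shapes are assumed purely for illustration:

```python
import torch

# Sketch only: shapes are illustrative, not taken from the commit.
bs, seq_len, dim = 2, 77, 768
num_images_per_prompt = 3
x = torch.randn(bs, seq_len, dim)

# Reference behaviour: repeat each batch entry n times, consecutively.
reference = x.repeat_interleave(num_images_per_prompt, dim=0)

# Alternative used by the commit: tile along the sequence dim, then reshape
# so the copies land on the batch dim in the same order.
alt = x.repeat(1, num_images_per_prompt, 1)
alt = alt.view(bs * num_images_per_prompt, seq_len, -1)

assert torch.equal(reference, alt)  # same values, same ordering
```
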
docs/source/optimization/mps.mdx (2 changes: 1 addition & 1 deletion)

@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
- Mac computer with Apple silicon (M1/M2) hardware.
- macOS 12.3 or later.
- arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
+- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.

## Inference Pipeline

@@ -218,8 +218,10 @@ def __call__(
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

-# duplicate text embeddings for each generation per prompt
-text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+# duplicate text embeddings for each generation per prompt, using mps friendly method
+bs_embed, seq_len, _ = text_embeddings.shape
+text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -256,8 +258,10 @@ def __call__(
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

-# duplicate unconditional embeddings for each generation per prompt
-uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+seq_len = uncond_embeddings.shape[1]
+uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
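
The second hunk applies the same pattern to the unconditional embeddings, which, as the commit message notes, are always created from a single prompt (a batch dimension of 1) and then expanded to `batch_size * num_images_per_prompt`. A small sketch under that assumption, with illustrative shapes that are not taken from the commit:

```python
import torch

# Sketch only: the negative prompt is encoded once, so the batch dim is 1.
batch_size, num_images_per_prompt = 2, 3
seq_len, dim = 77, 768  # illustrative sequence length and hidden size
uncond = torch.randn(1, seq_len, dim)

# Previous behaviour: repeat the single row once per generated image.
reference = uncond.repeat_interleave(batch_size * num_images_per_prompt, dim=0)

# Replacement mirroring the hunk above: tile, then reshape onto the batch dim.
alt = uncond.repeat(batch_size, num_images_per_prompt, 1)
alt = alt.view(batch_size * num_images_per_prompt, seq_len, -1)

assert torch.equal(reference, alt)  # shape (6, 77, 768), every row identical
```

Either way, the result is one copy of the unconditional embedding per generated image, which is what the subsequent classifier-free-guidance concatenation expects.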
