Commit 0b17046

Comment out unused imports.
monorimet committed Jun 26, 2024
1 parent 3ba2cfd commit 0b17046
Showing 5 changed files with 8 additions and 15 deletions.
12 changes: 0 additions & 12 deletions src/diffusers/models/attention_processor.py
@@ -1029,19 +1029,8 @@ def __call__(
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-<<<<<<< HEAD

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-=======
-# trace_tensor("query", query[0,0,0])
-# trace_tensor("key", key[0,0,0])
-# trace_tensor("value", value[0,0,0])
-hidden_states = hidden_states = F.scaled_dot_product_attention(
-    query, key, value, dropout_p=0.0, is_causal=False
-)
-#trace_tensor("attn_out", hidden_states[0,0,0,0])
-
->>>>>>> ea85c86bc (SD3 bringup)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -1051,7 +1040,6 @@ def __call__(
hidden_states[:, residual.shape[1] :],
)
hidden_states_cl = hidden_states.clone()
-#trace_tensor("attn_out", hidden_states_cl[0,0,0])
# linear proj
hidden_states = attn.to_out[0](hidden_states_cl)
# dropout
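The hunks above resolve a stale merge conflict around the fused attention call, keeping the single-assignment HEAD version and dropping the duplicated hidden_states = hidden_states = form along with the trace calls. For readers unfamiliar with the pattern, a minimal, self-contained sketch of the same multi-head reshape and F.scaled_dot_product_attention flow in plain PyTorch (shapes are illustrative, not taken from the pipeline):

import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 77, 8, 64
inner_dim = heads * head_dim

# Illustrative projections; in diffusers these come from attn.to_q/to_k/to_v.
query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# Split the channel dim into heads: (B, S, H*D) -> (B, H, S, D).
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Fused attention kernel; a single assignment, matching the resolved HEAD side.
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# Merge heads back: (B, H, S, D) -> (B, S, H*D).
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)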
2 changes: 2 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -23,6 +23,8 @@
from .activations import FP32SiLU, get_activation
from .attention_processor import Attention

+from shark_turbine.ops.iree import trace_tensor
+

def get_timestep_embedding(
timesteps: torch.Tensor,
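This new import backs the trace_tensor debugging calls that appear (commented out) elsewhere in this commit. A small usage sketch, assuming shark_turbine is installed; the signature (a label string plus a tensor) is inferred from the calls in the diff above:

import torch
# Assumption: the shark_turbine package is available; this is the exact import
# added in embeddings.py. trace_tensor emits the named tensor through IREE's
# tracing machinery so it can be inspected after compilation.
from shark_turbine.ops.iree import trace_tensor

q = torch.randn(2, 8, 77, 64)
# Label + tensor slice, mirroring the commented-out calls in attention_processor.py.
trace_tensor("query", q[0, 0, 0])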
@@ -813,7 +813,8 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
+print("prompt_embeds", prompt_embeds)
+print("pooled_prompt_embeds", pooled_prompt_embeds)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -831,6 +832,8 @@
generator,
latents,
)
+print(latents)
+print(timesteps)

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
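The torch.cat in the first hunk is the standard classifier-free guidance batching trick: negative and positive embeddings are stacked along the batch dimension so the conditioned and unconditioned branches share one forward pass. A minimal sketch of that pattern with made-up shapes (not the pipeline's real tensors):

import torch

# Illustrative shapes only: (batch, tokens, channels) and (batch, channels).
negative_prompt_embeds = torch.randn(1, 77, 4096)
prompt_embeds = torch.randn(1, 77, 4096)
negative_pooled_prompt_embeds = torch.randn(1, 2048)
pooled_prompt_embeds = torch.randn(1, 2048)

# Stack unconditioned and conditioned inputs along dim 0 so a single model
# call produces both outputs needed for the guidance combination.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
assert prompt_embeds.shape[0] == 2  # batch doubled for CFG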
@@ -23,7 +23,7 @@
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-from shark_turbine.ops.iree import trace_tensor
+#from shark_turbine.ops.iree import trace_tensor

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_pndm.py
@@ -22,7 +22,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
-from shark_turbine.ops.iree import trace_tensor
+#from shark_turbine.ops.iree import trace_tensor

# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
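For context on the truncated hunk: betas_for_alpha_bar discretizes a continuous alpha_bar(t) curve into per-step betas such that the cumulative product of (1 - beta) tracks that curve. A sketch along the lines of the diffusers implementation the "Copied from" comment cites (treat it as an approximation, not this file's exact contents):

import math
import torch

def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
    """Discretize a continuous alpha_bar(t) curve into per-step betas."""
    def alpha_bar(t: float) -> float:
        # Cosine schedule in the style of Nichol & Dhariwal (2021).
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Choose beta_i so prod(1 - beta) follows alpha_bar; clip for stability.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)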