Image to image #25

Merged, 20 commits, Nov 10, 2022
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -94,5 +94,6 @@
 )

 from .pipelines import OneFlowStableDiffusionPipeline
+from .pipelines import OneFlowStableDiffusionImg2ImgPipeline
 from .pipeline_oneflow_utils import OneFlowDiffusionPipeline
 from .modeling_oneflow_utils import OneFlowModelMixin
74 changes: 37 additions & 37 deletions src/diffusers/models/attention_oneflow.py
@@ -65,27 +65,26 @@ def forward(self, hidden_states):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)

-        '''
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
-
-        # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-
-        # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
-        '''
-        hidden_states = torch._C.fused_multi_head_attention_inference(query_proj, key_proj, value_proj, self.num_heads)
+        if query_proj.device == torch.device("cpu"):
+            # transpose
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+
+            # get scores
+            scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+
+            attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+
+            # compute attention output
+            hidden_states = torch.matmul(attention_probs, value_states)
+
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
+        else:
+            hidden_states = torch._C.fused_multi_head_attention_inference(query_proj, key_proj, value_proj, self.num_heads)
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
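In this file `torch` is OneFlow's PyTorch-compatible API (the fork aliases `import oneflow as torch`), so `torch._C.fused_multi_head_attention_inference` is OneFlow's fused kernel; the change restores the eager per-op path only for CPU tensors, where the fused kernel is not available. Below is a minimal sketch, not part of the PR, of the equivalence the branch relies on; the kernel's `(query, key, value, num_heads)` signature and the split scale come from the diff, while the shapes and tolerance are assumptions:

    import math
    import oneflow as torch  # assumption: the same alias the fork uses

    batch, seq_len, num_heads, head_dim = 2, 16, 4, 8
    channels = num_heads * head_dim
    query = torch.randn(batch, seq_len, channels).cuda()
    key = torch.randn(batch, seq_len, channels).cuda()
    value = torch.randn(batch, seq_len, channels).cuda()

    def manual_attention(q, k, v):
        # (batch, seq, channels) -> (batch, heads, seq, head_dim),
        # mirroring transpose_for_scores
        def split(t):
            return t.reshape(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
        # the 1/sqrt(head_dim) scale is split across q and k, as in the diff
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        scores = torch.matmul(split(q) * scale, split(k).transpose(-1, -2) * scale)
        probs = torch.softmax(scores, dim=-1)
        out = torch.matmul(probs, split(v)).permute(0, 2, 1, 3)
        return out.reshape(batch, seq_len, channels)

    fused = torch._C.fused_multi_head_attention_inference(query, key, value, num_heads)
    assert torch.allclose(manual_attention(query, key, value), fused, atol=1e-3)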
@@ -259,23 +258,23 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.to_k(context)
         value = self.to_v(context)

-        '''
-        dim = query.shape[-1]
-
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
-
-        # TODO(PVP) - mask is currently never used. Remember to re-implement when used
-
-        # attention, what we cannot get enough of
-
-        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-            hidden_states = self._attention(query, key, value)
-        else:
-            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
-        '''
-        hidden_states = torch._C.fused_multi_head_attention_inference(query, key, value, self.heads)
+        if query.device == torch.device("cpu"):
+            dim = query.shape[-1]
+
+            query = self.reshape_heads_to_batch_dim(query)
+            key = self.reshape_heads_to_batch_dim(key)
+            value = self.reshape_heads_to_batch_dim(value)
+
+            # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+            # attention, what we cannot get enough of
+
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+        else:
+            hidden_states = torch._C.fused_multi_head_attention_inference(query, key, value, self.heads)

         return self.to_out(hidden_states)
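The CPU branch keeps upstream diffusers' attention slicing for long sequences. A rough sketch of the idea behind `_sliced_attention`, based on upstream diffusers rather than code in this PR (here `scale` would be `dim_head ** -0.5`):

    import oneflow as torch  # assumption: the same alias the fork uses

    def sliced_attention(query, key, value, slice_size, scale):
        # query/key/value: (batch * heads, seq, dim_head); attention is computed
        # over chunks of the first dimension to bound peak memory
        out = torch.zeros_like(query)
        for start in range(0, query.shape[0], slice_size):
            end = start + slice_size
            scores = torch.matmul(query[start:end], key[start:end].transpose(-1, -2)) * scale
            probs = torch.softmax(scores, dim=-1)
            out[start:end] = torch.matmul(probs, value[start:end])
        return out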

@@ -349,13 +348,14 @@ def __init__(self, dim_in: int, dim_out: int):
         self.proj = nn.Linear(dim_in, dim_out * 2)

     def forward(self, hidden_states):
-        x_shape = hidden_states.shape
-        if len(x_shape) != 2:
-            hidden_states = hidden_states.reshape(-1, x_shape[-1])
-        out = torch._C.fused_geglu(hidden_states, self.proj.weight, self.proj.bias)
-        if len(x_shape) != 2:
-            out_shape = x_shape[0:len(x_shape) - 1] + (-1,)
-            out = out.reshape(out_shape)
-        return out
+        if hasattr(torch._C, "fused_geglu"):
+            x_shape = hidden_states.shape
+            if len(x_shape) != 2:
+                hidden_states = hidden_states.reshape(-1, x_shape[-1])
+            out = torch._C.fused_geglu(hidden_states, self.proj.weight, self.proj.bias)
+            if len(x_shape) != 2:
+                out_shape = x_shape[0:len(x_shape) - 1] + (-1,)
+                out = out.reshape(out_shape)
+            return out
         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * F.gelu(gate)
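GEGLU projects to twice the output width, then gates one half with GELU of the other; `fused_geglu` presumably does this in a single kernel, and the new `hasattr` guard keeps the unfused two-line fallback above for builds that lack the op. A small reference sketch of the same computation (the function name is illustrative):

    import oneflow as torch  # assumption: the same alias the fork uses
    import oneflow.nn.functional as F

    def geglu_reference(hidden_states, weight, bias):
        # one projection yields 2 * dim_out features: a value half and a gate half
        projected = F.linear(hidden_states, weight, bias)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)

    proj = torch.nn.Linear(64, 32 * 2)
    out = geglu_reference(torch.randn(4, 64), proj.weight, proj.bias)  # shape (4, 32)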
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -22,3 +22,4 @@
 from .stable_diffusion import FlaxStableDiffusionPipeline

 from .stable_diffusion import OneFlowStableDiffusionPipeline
+from .stable_diffusion import OneFlowStableDiffusionImg2ImgPipeline
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -60,4 +60,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
 from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

 from .pipeline_stable_diffusion_oneflow import OneFlowStableDiffusionPipeline
+from .pipeline_stable_diffusion_img2img_oneflow import OneFlowStableDiffusionImg2ImgPipeline
 from .safety_checker_oneflow import OneFlowStableDiffusionSafetyChecker
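With these exports in place, the new pipeline is importable from the package root. A hypothetical usage sketch, assuming `OneFlowStableDiffusionImg2ImgPipeline` mirrors the diffusers img2img API of this period; the model id and the `init_image`/`strength`/`guidance_scale` arguments are illustrative, not confirmed by this diff:

    from PIL import Image
    from diffusers import OneFlowStableDiffusionImg2ImgPipeline

    # illustrative model id and arguments, assuming the upstream img2img interface
    pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=True,
    ).to("cuda")

    init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))
    result = pipe(
        prompt="A fantasy landscape, trending on artstation",
        init_image=init_image,
        strength=0.75,
        guidance_scale=7.5,
    )
    result.images[0].save("fantasy_landscape.png")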