Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix is_inference_mode (PaddlePaddle#711)
Browse files Browse the repository at this point in the history
paddle 3.0beta does not support `from paddle.incubate.jit import is_inference_mode`; this fixes it.
zhoutianzi666 authored Sep 18, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 0f37251 commit ef51185
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,14 @@
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

try:
    # `paddle.incubate.jit.is_inference_mode` exists in paddle develop but not
    # in paddle 3.0beta, so guard the import and fall back to a stub.
    from paddle.incubate.jit import is_inference_mode
except ImportError:
    # Catch only ImportError: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and mask unrelated failures inside paddle.

    def is_inference_mode(func):
        """Fallback when the real helper is unavailable: report that *func*
        never runs in paddle inference mode."""
        return False


class DiTPipeline(DiffusionPipeline):
r"""
@@ -192,7 +200,7 @@ def __call__(
)
# predict noise model_output
noise_pred_out = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)
if paddle.incubate.jit.is_inference_mode(self.transformer):
if is_inference_mode(self.transformer):
# self.transformer run in paddle inference.
noise_pred = noise_pred_out
else:
@@ -227,7 +235,7 @@ def __call__(
latents = 1 / self.vae.config.scaling_factor * latents

samples_out = self.vae.decode(latents)
if paddle.incubate.jit.is_inference_mode(self.vae.decode):
if is_inference_mode(self.vae.decode):
# self.vae.decode run in paddle inference.
samples = samples_out
else:

0 comments on commit ef51185

Please sign in to comment.