Skip to content

Commit

Permalink
feat: "lora_name" as input
Browse files Browse the repository at this point in the history
  • Loading branch information
TimPietrusky committed Oct 5, 2024
1 parent e248024 commit b814d99
Showing 1 changed file with 68 additions and 53 deletions.
121 changes: 68 additions & 53 deletions worker_runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,6 @@
from comfy import model_management
import base64


def download_file(url, save_dir="/content/ComfyUI/models/loras"):
os.makedirs(save_dir, exist_ok=True)
file_name = url.split("/")[-1]
file_path = os.path.join(save_dir, file_name)
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as file:
file.write(response.content)
return file_path


# Initialize Model Loaders
DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
Expand Down Expand Up @@ -68,66 +56,93 @@ def closestNumber(n, m):
def generate(input):
values = input["input"]

positive_prompt = values["positive_prompt"]
width = values["width"]
height = values["height"]
seed = values["seed"]
steps = values["steps"]
guidance = values["guidance"]
lora_strength_model = values["lora_strength_model"]
lora_strength_clip = values["lora_strength_clip"]
sampler_name = values["sampler_name"]
scheduler = values["scheduler"]
lora_url = values["lora_url"]

# Download and load LoRa model
lora_file = download_file(lora_url)
lora_file = os.path.basename(lora_file)
positive_prompt = values.get("positive_prompt", "")
width = values.get("width", 512)
height = values.get("height", 512)
seed = values.get("seed", 0)
steps = values.get("steps", 50)
guidance = values.get("guidance", 7.5)
lora_strength_model = values.get("lora_strength_model", 0.8)
lora_strength_clip = values.get("lora_strength_clip", 0.8)
sampler_name = values.get("sampler_name", "Euler")
scheduler = values.get("scheduler", "default")
job_id = values.get("job_id", "test-job-123")
lora_name = values.get("lora_name", "zanshou-kin-flux-ueno-manga-style.safetensors")

# Path to the LoRA model based on lora_name
lora_file_path = f"models/loras/{lora_name}"

# Validate if the specified LoRA model exists
if not os.path.exists(lora_file_path):
error_response = {
"jobId": job_id,
"result": f"FAILED: LoRA model '{lora_name}' not found.",
"status": "FAILED",
}
print(
f"Error: LoRA model '{lora_name}' does not exist at path '{lora_file_path}'."
)
return error_response

# Handle seed
if seed == 0:
random.seed(int(time.time()))
seed = random.randint(0, 18446744073709551615)
print(seed)

global unet, clip
unet_lora, clip_lora = LoraLoader.load_lora(
unet, clip, lora_file, lora_strength_model, lora_strength_clip
)
cond, pooled = clip_lora.encode_from_tokens(
clip_lora.tokenize(positive_prompt), return_pooled=True
)
cond = [[cond, {"pooled_output": pooled}]]
cond = FluxGuidance.append(cond, guidance)[0]
noise = RandomNoise.get_noise(seed)[0]
guider = BasicGuider.get_guider(unet_lora, cond)[0]
sampler = KSamplerSelect.get_sampler(sampler_name)[0]
sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
latent_image = EmptyLatentImage.generate(
closestNumber(width, 16), closestNumber(height, 16)
)[0]
sample, sample_denoised = SamplerCustomAdvanced.sample(
noise, guider, sampler, sigmas, latent_image
)
decoded = VAEDecode.decode(vae, sample)[0].detach()
image_path = "/content/flux.png"
Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0]).save(image_path)
print(f"Using seed: {seed}")

try:
# Load LoRA models from the specified file
unet_lora, clip_lora = LoraLoader.load_lora(
unet, clip, lora_file_path, lora_strength_model, lora_strength_clip
)

# Encode the positive prompt
cond, pooled = clip_lora.encode_from_tokens(
clip_lora.tokenize(positive_prompt), return_pooled=True
)
cond = [[cond, {"pooled_output": pooled}]]
cond = FluxGuidance.append(cond, guidance)[0]

# Generate noise based on the seed
noise = RandomNoise.get_noise(seed)[0]

# Initialize the guider and sampler
guider = BasicGuider.get_guider(unet_lora, cond)[0]
sampler = KSamplerSelect.get_sampler(sampler_name)[0]
sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]

# Generate an empty latent image
latent_image = EmptyLatentImage.generate(
closestNumber(width, 16), closestNumber(height, 16)
)[0]

# Perform the sampling
sample, sample_denoised = SamplerCustomAdvanced.sample(
noise, guider, sampler, sigmas, latent_image
)

# Decode the image using VAE
decoded = VAEDecode.decode(vae, sample)[0].detach()

# Save the image to a file
image_path = "flux.png"
Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0]).save(image_path)

# Open and encode the image in Base64
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

# Prepare the response
response = {"jobId": values["job_id"], "image": encoded_image, "status": "DONE"}
response = {"jobId": job_id, "image": encoded_image, "status": "DONE"}
return response

except Exception as e:
error_response = {
"jobId": values.get("job_id", "unknown"),
"jobId": job_id,
"result": f"FAILED: {str(e)}",
"status": "FAILED",
}
print(f"Error processing job {job_id}: {str(e)}")
return error_response

finally:
Expand Down

0 comments on commit b814d99

Please sign in to comment.