diff --git a/worker_runpod.py b/worker_runpod.py
index 438e92f..78f9211 100644
--- a/worker_runpod.py
+++ b/worker_runpod.py
@@ -15,18 +15,6 @@ from comfy import model_management
 import base64
 
 
-
-def download_file(url, save_dir="/content/ComfyUI/models/loras"):
-    os.makedirs(save_dir, exist_ok=True)
-    file_name = url.split("/")[-1]
-    file_path = os.path.join(save_dir, file_name)
-    response = requests.get(url)
-    response.raise_for_status()
-    with open(file_path, "wb") as file:
-        file.write(response.content)
-    return file_path
-
-
 # Initialize Model Loaders
 DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
 UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
@@ -68,66 +56,93 @@ def closestNumber(n, m):
 
 def generate(input):
     values = input["input"]
-    positive_prompt = values["positive_prompt"]
-    width = values["width"]
-    height = values["height"]
-    seed = values["seed"]
-    steps = values["steps"]
-    guidance = values["guidance"]
-    lora_strength_model = values["lora_strength_model"]
-    lora_strength_clip = values["lora_strength_clip"]
-    sampler_name = values["sampler_name"]
-    scheduler = values["scheduler"]
-    lora_url = values["lora_url"]
-
-    # Download and load LoRa model
-    lora_file = download_file(lora_url)
-    lora_file = os.path.basename(lora_file)
+    positive_prompt = values.get("positive_prompt", "")
+    width = values.get("width", 512)
+    height = values.get("height", 512)
+    seed = values.get("seed", 0)
+    steps = values.get("steps", 50)
+    guidance = values.get("guidance", 7.5)
+    lora_strength_model = values.get("lora_strength_model", 0.8)
+    lora_strength_clip = values.get("lora_strength_clip", 0.8)
+    sampler_name = values.get("sampler_name", "Euler")
+    scheduler = values.get("scheduler", "default")
+    job_id = values.get("job_id", "test-job-123")
+    lora_name = values.get("lora_name", "zanshou-kin-flux-ueno-manga-style.safetensors")
+
+    # Path to the LoRA model based on lora_name
+    lora_file_path = f"models/loras/{lora_name}"
+
+    # Validate if the specified LoRA model exists
+    if not os.path.exists(lora_file_path):
+        error_response = {
+            "jobId": job_id,
+            "result": f"FAILED: LoRA model '{lora_name}' not found.",
+            "status": "FAILED",
+        }
+        print(
+            f"Error: LoRA model '{lora_name}' does not exist at path '{lora_file_path}'."
+        )
+        return error_response
 
     # Handle seed
     if seed == 0:
         random.seed(int(time.time()))
         seed = random.randint(0, 18446744073709551615)
-    print(seed)
-
-    global unet, clip
-    unet_lora, clip_lora = LoraLoader.load_lora(
-        unet, clip, lora_file, lora_strength_model, lora_strength_clip
-    )
-    cond, pooled = clip_lora.encode_from_tokens(
-        clip_lora.tokenize(positive_prompt), return_pooled=True
-    )
-    cond = [[cond, {"pooled_output": pooled}]]
-    cond = FluxGuidance.append(cond, guidance)[0]
-    noise = RandomNoise.get_noise(seed)[0]
-    guider = BasicGuider.get_guider(unet_lora, cond)[0]
-    sampler = KSamplerSelect.get_sampler(sampler_name)[0]
-    sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
-    latent_image = EmptyLatentImage.generate(
-        closestNumber(width, 16), closestNumber(height, 16)
-    )[0]
-    sample, sample_denoised = SamplerCustomAdvanced.sample(
-        noise, guider, sampler, sigmas, latent_image
-    )
-    decoded = VAEDecode.decode(vae, sample)[0].detach()
-    image_path = "/content/flux.png"
-    Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0]).save(image_path)
+    print(f"Using seed: {seed}")
 
     try:
+        # Load LoRA models from the specified file
+        unet_lora, clip_lora = LoraLoader.load_lora(
+            unet, clip, lora_file_path, lora_strength_model, lora_strength_clip
+        )
+
+        # Encode the positive prompt
+        cond, pooled = clip_lora.encode_from_tokens(
+            clip_lora.tokenize(positive_prompt), return_pooled=True
+        )
+        cond = [[cond, {"pooled_output": pooled}]]
+        cond = FluxGuidance.append(cond, guidance)[0]
+
+        # Generate noise based on the seed
+        noise = RandomNoise.get_noise(seed)[0]
+
+        # Initialize the guider and sampler
+        guider = BasicGuider.get_guider(unet_lora, cond)[0]
+        sampler = KSamplerSelect.get_sampler(sampler_name)[0]
+        sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
+
+        # Generate an empty latent image
+        latent_image = EmptyLatentImage.generate(
+            closestNumber(width, 16), closestNumber(height, 16)
+        )[0]
+
+        # Perform the sampling
+        sample, sample_denoised = SamplerCustomAdvanced.sample(
+            noise, guider, sampler, sigmas, latent_image
+        )
+
+        # Decode the image using VAE
+        decoded = VAEDecode.decode(vae, sample)[0].detach()
+
+        # Save the image to a file
+        image_path = "flux.png"
+        Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0]).save(image_path)
+
         # Open and encode the image in Base64
         with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
 
         # Prepare the response
-        response = {"jobId": values["job_id"], "image": encoded_image, "status": "DONE"}
+        response = {"jobId": job_id, "image": encoded_image, "status": "DONE"}
         return response
 
     except Exception as e:
         error_response = {
-            "jobId": values.get("job_id", "unknown"),
+            "jobId": job_id,
             "result": f"FAILED: {str(e)}",
             "status": "FAILED",
         }
+        print(f"Error processing job {job_id}: {str(e)}")
         return error_response
 
     finally:
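
Reviewer note: the sketch below is not part of the patch. It is a minimal, illustrative example of the job payload the reworked generate() handler now expects; every value simply restates the .get() defaults introduced in this diff, and the prompt string is invented for the example.

# Illustrative usage sketch (assumes the worker module is loaded and its models
# are initialized; values mirror the defaults added in this diff).
example_job = {
    "input": {
        "positive_prompt": "a quiet street at dusk, manga style",  # illustrative text
        "width": 512,
        "height": 512,
        "seed": 0,  # 0 triggers the time-seeded random fallback
        "steps": 50,
        "guidance": 7.5,
        "lora_strength_model": 0.8,
        "lora_strength_clip": 0.8,
        "sampler_name": "Euler",
        "scheduler": "default",
        "job_id": "test-job-123",
        "lora_name": "zanshou-kin-flux-ueno-manga-style.safetensors",  # must exist under models/loras/
    }
}

# generate() is the handler patched above.
result = generate(example_job)
# Success:               {"jobId": ..., "image": <base64 PNG>, "status": "DONE"}
# Missing LoRA or error: {"jobId": ..., "result": "FAILED: ...", "status": "FAILED"}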