# IFChatPromptNode.py
import os
import sys
import json
import torch
import shutil
import base64
import platform
import importlib
import subprocess
import numpy as np
import folder_paths
from PIL import Image
import yaml
from io import BytesIO
import asyncio
from typing import List, Union, Dict, Any, Tuple, Optional
from .agent_tool import AgentTool
from .send_request import send_request
#from .transformers_api import TransformersModelManager 
import tempfile
import threading
import codecs
from aiohttp import web
from .graphRAG_module import GraphRAGapp
from .colpaliRAG_module import colpaliRAGapp
from .superflorence import FlorenceModule
from .utils import get_api_key, get_models, validate_models, clean_text, process_mask, load_placeholder_image, process_images_for_comfy
#from byaldi import RAGMultiModalModel 
# Set up logging
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Add the ComfyUI directory to the Python path
comfy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.insert(0, comfy_path)

ifchat_prompt_node = None

try:
    from server import PromptServer

    @PromptServer.instance.routes.post("/IF_ChatPrompt/get_llm_models")
    async def get_llm_models_endpoint(request):
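        """Return the list of available LLM models for the requested provider as JSON."""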
        data = await request.json()
        llm_provider = data.get("llm_provider")
        engine = llm_provider
        base_ip = data.get("base_ip")
        port = data.get("port")
        external_api_key = data.get("external_api_key")
        
        logger.debug(f"Received request for LLM models. Provider: {llm_provider}, External API key provided: {bool(external_api_key)}")

        if external_api_key:
            api_key = external_api_key
            logger.debug("Using provided external LLM API key")
        else:
            api_key_name = f"{llm_provider.upper()}_API_KEY"
            try:
                api_key = get_api_key(api_key_name, engine)
                logger.debug("Using API key from environment or .env file")
            except ValueError:
                logger.warning(f"No API key found for {llm_provider}. Attempting to proceed without an API key.")
                api_key = None

        models = get_models(engine, base_ip, port, api_key)
        logger.debug(f"Fetched {len(models)} models for {llm_provider}")
        return web.json_response(models)

    @PromptServer.instance.routes.post("/IF_ChatPrompt/get_embedding_models")
    async def get_embedding_models_endpoint(request):
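        """Return the list of available embedding models for the requested provider as JSON."""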
        data = await request.json()
        embedding_provider = data.get("embedding_provider")
        engine = embedding_provider
        base_ip = data.get("base_ip")
        port = data.get("port")
        external_api_key = data.get("external_api_key")
        
        logger.debug(f"Received request for LLM models. Provider: {embedding_provider}, External API key provided: {bool(external_api_key)}")

        if external_api_key:
            api_key = external_api_key
            logger.debug("Using provided external LLM API key")
        else:
            api_key_name = f"{embedding_provider.upper()}_API_KEY"
            try:
                api_key = get_api_key(api_key_name, engine)
                logger.debug("Using API key from environment or .env file")
            except ValueError:
                logger.warning(f"No API key found for {embedding_provider}. Attempting to proceed without an API key.")
                api_key = None

        models = get_models(engine, base_ip, port, api_key)
        logger.debug(f"Fetched {len(models)} models for {embedding_provider}")
        return web.json_response(models)

    @PromptServer.instance.routes.post("/IF_ChatPrompt/upload_file")
    async def upload_file_route(request):
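        """Receive a multipart file upload and store it in the RAG folder's input directory."""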
        try:
            reader = await request.multipart()

            rag_folder_name = None
            file_content = None
            filename = None

            # Process all parts of the multipart request
            while True:
                part = await reader.next()
                if part is None:
                    break
                if part.name == "rag_root_dir":
                    rag_folder_name = await part.text()
                elif part.filename:
                    filename = part.filename
                    file_content = await part.read()

            if not filename or not file_content or not rag_folder_name:
                return web.json_response({"status": "error", "message": "Missing file, filename, or RAG folder name"})

            node = IFChatPrompt()
            input_dir = os.path.join(node.rag_dir, rag_folder_name, "input")

            if not os.path.exists(input_dir):
                os.makedirs(input_dir, exist_ok=True)

            file_path = os.path.join(input_dir, filename)

            with open(file_path, 'wb') as f:
                f.write(file_content)

            logger.info(f"File uploaded to: {file_path}")
            return web.json_response({"status": "success", "message": f"File uploaded to: {file_path}"})

        except Exception as e:
            logger.error(f"Error in upload_file_route: {str(e)}")
            return web.json_response({"status": "error", "message": f"Error uploading file: {str(e)}"})

    @PromptServer.instance.routes.post("/IF_ChatPrompt/setup_and_initialize")
    async def setup_and_initialize(request):
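        """Set up and initialize the RAG folder via GraphRAG and point the ColPali app at it."""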
        global ifchat_prompt_node
        
        data = await request.json()
        folder_name = data.get('folder_name', 'rag_data')
        
        if ifchat_prompt_node is None:
            ifchat_prompt_node = IFChatPrompt()
        
        init_result = await ifchat_prompt_node.graphrag_app.setup_and_initialize_folder(folder_name, data)
        
        ifchat_prompt_node.rag_folder_name = folder_name
        ifchat_prompt_node.colpali_app.set_rag_root_dir(folder_name)   
        
        return web.json_response(init_result)

    @PromptServer.instance.routes.post("/IF_ChatPrompt/run_indexer")
    async def run_indexer_endpoint(request):
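        """Run indexing for the selected RAG backend (ColPali variants or GraphRAG)."""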
        try:
            data = await request.json()
            logger.debug(f"Received indexing request with data: {data}")

            # Access the global instance, creating it if it has not been initialized yet
            global ifchat_prompt_node
            if ifchat_prompt_node is None:
                ifchat_prompt_node = IFChatPrompt()

            # Set the rag_root_dir in both modules using the global instance
            ifchat_prompt_node.graphrag_app.set_rag_root_dir(data.get('rag_folder_name'))
            ifchat_prompt_node.colpali_app.set_rag_root_dir(data.get('rag_folder_name'))

            query_type = data.get('mode_type')
            logger.debug(f"Query type: {query_type}")

            logger.debug(f"Starting indexing process for query type: {query_type}")

            # Initialize the colpali_model before calling insert, using the global instance
            if query_type in ('colpali', 'colqwen2', 'colpali-v1.2'):
                _ = ifchat_prompt_node.colpali_app.get_colpali_model(query_type)  # Load or retrieve the cached model
                result = await ifchat_prompt_node.colpali_app.insert()
            else:
                result = await ifchat_prompt_node.graphrag_app.insert()

            logger.debug(f"Indexing process completed with result: {result}")

            if result:
                return web.json_response({"status": "success", "message": f"Indexing complete for {query_type}"})
            else:
                return web.json_response({"status": "error", "message": "Indexing failed. Check server logs."}, status=500)

        except Exception as e:
            logger.error(f"Error in run_indexer_endpoint: {str(e)}")
            return web.json_response({"status": "error", "message": f"Error during indexing: {str(e)}"}, status=500)
        
    @PromptServer.instance.routes.post("/IF_ChatPrompt/process_chat")
    async def process_chat_endpoint(request):
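        """Process a chat request over HTTP, filling in defaults for any missing arguments."""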
        data = {}  # Ensure `data` exists for the error response even if JSON parsing fails
        try:
            data = await request.json()
            
            # Set default values for required arguments if not provided
            defaults = {
                "prompt": "",
                "assistant": "Cortana",  # Default assistant
                "neg_prompt": "Default",  # Default negative prompt
                "embellish_prompt": "Default",  # Default embellishment
                "style_prompt": "Default",  # Default style
                "llm_provider": "ollama",
                "llm_model": "",
                "base_ip": "localhost",
                "port": "11434",
                "embedding_model": "",
                "embedding_provider": "sentence_transformers"
            }
            
            # Update data with defaults for missing keys
            for key, default_value in defaults.items():
                if key not in data:
                    data[key] = default_value
                    
            global ifchat_prompt_node
            if ifchat_prompt_node is None:
                ifchat_prompt_node = IFChatPrompt()
            result = await ifchat_prompt_node.process_chat(**data)
            
            return web.json_response(result)
            
        except Exception as e:
            logger.error(f"Error in process_chat_endpoint: {str(e)}")
            return web.json_response({
                "status": "error",
                "message": f"Error processing chat: {str(e)}",
                "Question": data.get("prompt", ""),
                "Response": f"Error: {str(e)}",
                "Negative": "",
                "Tool_Output": None,
                "Retrieved_Image": None,
                "Mask": None
            }, status=500)

    @PromptServer.instance.routes.post("/IF_ChatPrompt/load_index")
    async def load_index_route(request):
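        """Load a previously built ColPali index from the .byaldi directory and cache it."""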
        try:
            data = await request.json()
            index_name = data.get('rag_folder_name')
            query_type = data.get('query_type')
            
            if not index_name:
                logger.error("No index name provided in the request.")
                return web.json_response({
                    "status": "error", 
                    "message": "No index name provided"
                })

            # Check if index exists in .byaldi directory
            byaldi_index_path = os.path.join(".byaldi", index_name)
            if not os.path.exists(byaldi_index_path):
                logger.error(f"Index not found in .byaldi: {byaldi_index_path}")
                return web.json_response({
                    "status": "error",
                    "message": f"Index {index_name} does not exist"
                })

            try:
                global ifchat_prompt_node
                if ifchat_prompt_node is None:
                    logger.debug("Initializing IFChatPrompt instance.")
                    ifchat_prompt_node = IFChatPrompt()

                if query_type in ['colpali', 'colqwen2', 'colpali-v1.2']:
                    logger.debug(f"Loading model for query type: {query_type}")
                    
                    # Clear any existing cached index
                    ifchat_prompt_node.colpali_app.cleanup_index()
                    
                    # First get the base model
                    colpali_model = ifchat_prompt_node.colpali_app.get_colpali_model(query_type)
                    
                    if colpali_model:
                        # Load and cache the new index
                        model = await ifchat_prompt_node.colpali_app._prepare_model(query_type, index_name)
                        if not model:
                            raise ValueError("Failed to load and cache index")
                        
                        # Set the RAG root directory
                        ifchat_prompt_node.colpali_app.set_rag_root_dir(index_name)
                        
                        logger.info(f"Successfully loaded and cached index: {index_name}")
                        return web.json_response({
                            "status": "success",
                            "message": f"Successfully loaded index: {index_name}",
                            "rag_root_dir": index_name
                        })
                    else:
                        logger.error("Failed to initialize ColPali model.")
                        raise ValueError("Failed to initialize ColPali model")
                
                else:
                    logger.error(f"Unsupported query type: {query_type}")
                    return web.json_response({
                        "status": "error",
                        "message": f"Query type {query_type} not supported for loading indexes"
                    })

            except Exception as e:
                logger.error(f"Error loading index {index_name}: {str(e)}")
                return web.json_response({
                    "status": "error",
                    "message": f"Error loading index: {str(e)}"
                })

        except Exception as e:
            logger.error(f"Error in load_index_route: {str(e)}")
            return web.json_response({
                "status": "error",
                "message": f"Error processing request: {str(e)}"
            })

    @PromptServer.instance.routes.post("/IF_ChatPrompt/delete_index")
    async def delete_index_route(request):
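        """Delete a previously built index directory from .byaldi."""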
        try:
            data = await request.json()
            index_name = data.get('rag_folder_name')
            
            if not index_name:
                return web.json_response({
                    "status": "error", 
                    "message": "No index name provided"
                })

            # Path to the index
            index_path = os.path.join(".byaldi", index_name)
            
            if not os.path.exists(index_path):
                return web.json_response({
                    "status": "error",
                    "message": f"Index {index_name} does not exist"
                })

            # Delete the index directory
            try:
                shutil.rmtree(index_path)
                logger.info(f"Successfully deleted index: {index_name}")
                return web.json_response({
                    "status": "success",
                    "message": f"Successfully deleted index: {index_name}"
                })
            except Exception as e:
                logger.error(f"Error deleting index {index_name}: {str(e)}")
                return web.json_response({
                    "status": "error",
                    "message": f"Error deleting index: {str(e)}"
                })

        except Exception as e:
            logger.error(f"Error in delete_index_route: {str(e)}")
            return web.json_response({
                "status": "error",
                "message": f"Error processing request: {str(e)}"
            })

except AttributeError:
    logger.warning("PromptServer.instance not available. Skipping route registration for IF_ChatPrompt.")

class IFChatPrompt:

    def __init__(self):
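        """Set up default connection/model settings, preset file paths, agent tools, RAG and Florence helpers, and the placeholder image."""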
        self.base_ip = "localhost"
        self.port = "11434"
        self.llm_provider = "ollama"
        self.embedding_provider = "sentence_transformers"
        self.llm_model = ""
        self.embedding_model = ""
        self.assistant = "None"
        self.random = False

        self.comfy_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        self.rag_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "rag")
        self.presets_dir = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "presets")
        
        self.stop_file = os.path.join(self.presets_dir, "stop_strings.json")
        self.assistants_file = os.path.join(self.presets_dir, "assistants.json")
        self.neg_prompts_file = os.path.join(self.presets_dir, "neg_prompts.json")
        self.embellish_prompts_file = os.path.join(self.presets_dir, "embellishments.json")
        self.style_prompts_file = os.path.join(self.presets_dir, "style_prompts.json")
        self.tasks_file = os.path.join(self.presets_dir, "florence_prompts.json")
        self.agents_dir = os.path.join(self.presets_dir, "agents")

        self.agent_tools = self.load_agent_tools()
        self.stop_strings = self.load_presets(self.stop_file)
        self.assistants = self.load_presets(self.assistants_file)
        self.neg_prompts = self.load_presets(self.neg_prompts_file)
        self.embellish_prompts = self.load_presets(self.embellish_prompts_file)
        self.style_prompts = self.load_presets(self.style_prompts_file)
        self.florence_prompts = self.load_presets(self.tasks_file)

        self.keep_alive = False
        self.seed = 94687328150
        self.messages = []
        self.history_steps = 10
        self.external_api_key = ""
        self.tool_input = ""
        self.prime_directives = None
        self.rag_folder_name = "rag_data"
        self.graphrag_app = GraphRAGapp()
        self.colpali_app = colpaliRAGapp()
        self.fix_json = True
        self.cached_colpali_model = None
        self.florence_app = FlorenceModule()
        self.florence_models = {}
        self.query_type = "global"  
        self.enable_RAG = False
        self.clear_history = False
        self.mode = False
        self.tool = "None"
        self.preset = "Default"
        self.precision = "fp16"
        self.task = None  
        self.attention = "sdpa" 
        self.aspect_ratio = "16:9"
        self.top_k_search = 3
        
        self.placeholder_image_path = os.path.join(folder_paths.base_path, "custom_nodes", "ComfyUI-IF_AI_tools", "IF_AI", "placeholder.png")

        if not os.path.exists(self.placeholder_image_path):
            placeholder = Image.new('RGB', (512, 512), color=(73, 109, 137))
            os.makedirs(os.path.dirname(self.placeholder_image_path), exist_ok=True)
            placeholder.save(self.placeholder_image_path)

    def load_presets(self, file_path: str) -> Dict[str, Any]:
        """
        Load JSON presets with support for multiple encodings.
        
        Args:
            file_path (str): Path to the JSON preset file
            
        Returns:
            Dict[str, Any]: Loaded JSON data or empty dict if loading fails
        """
        # List of encodings to try
        encodings = ['utf-8', 'utf-8-sig', 'latin1', 'cp1252', 'gbk']
        
        for encoding in encodings:
            try:
                with codecs.open(file_path, 'r', encoding=encoding) as f:
                    data = json.load(f)
                    
                    # If successful, write back with UTF-8 encoding to prevent future issues
                    try:
                        with codecs.open(file_path, 'w', encoding='utf-8') as out_f:
                            json.dump(data, out_f, ensure_ascii=False, indent=2)
                    except Exception as write_err:
                        print(f"Warning: Could not write back UTF-8 encoded file: {write_err}")
                        
                    return data
                    
            except UnicodeDecodeError:
                continue
            except json.JSONDecodeError as e:
                print(f"JSON parsing error with {encoding} encoding: {str(e)}")
                continue
            except Exception as e:
                print(f"Error loading presets from {file_path} with {encoding} encoding: {e}")
                continue
                
        print(f"Error: Failed to load {file_path} with any supported encoding")
        return {}

    def load_agent_tools(self):
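        """Load agent tool definitions from the agents directory and return them keyed by tool name."""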
        os.makedirs(self.agents_dir, exist_ok=True)
        agent_tools = {}
        try:
            for filename in os.listdir(self.agents_dir):
                if filename.endswith('.json'):
                    full_path = os.path.join(self.agents_dir, filename)
                    with open(full_path, 'r', encoding='utf-8') as f:
                        try:
                            data = json.load(f)
                            if 'output_type' not in data:
                                data['output_type'] = None
                            agent_tool = AgentTool(**data)
                            agent_tool.load()
                            if agent_tool._class_instance is not None:
                                if agent_tool.python_function:
                                    agent_tools[agent_tool.name] = agent_tool
                                else:
                                    print(f"Warning: Agent tool {agent_tool.name} in {filename} does not have a python_function defined.")
                            else:
                                print(f"Failed to create class instance for {filename}")
                        except json.JSONDecodeError:
                            print(f"Error: Invalid JSON in {filename}")
                        except Exception as e:
                            print(f"Error loading {filename}: {str(e)}")
            return agent_tools
        except Exception as e:
            print(f"Warning: Error accessing agent tools directory: {str(e)}")
            return {}

    async def process_chat(
        self,
        prompt,
        llm_provider,
        llm_model,
        base_ip,
        port,
        assistant,
        neg_prompt,
        embellish_prompt,
        style_prompt,
        embedding_model,
        embedding_provider,
        external_api_key="",
        temperature=0.7,
        max_tokens=2048,
        seed=0,
        random=False,
        history_steps=10,
        keep_alive=False,
        top_k=40,
        top_p=0.2,
        repeat_penalty=1.1,
        stop_string=None,
        images=None,
        mode=True,
        clear_history=False,
        text_cleanup=True,
        tool=None,
        tool_input=None,
        prime_directives=None,
        enable_RAG=False,
        query_type="global",
        preset="Default",
        rag_folder_name=None,
        task=None,
        fill_mask=False,
        output_mask_select="",
        precision="fp16",
        attention="sdpa",
        aspect_ratio="16:9",
        top_k_search=3
    ):
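        """Run one chat turn: execute the selected agent tool if one is chosen, otherwise
        generate a response (optionally RAG- or Florence-assisted) and build the combined
        and negative prompts for the node outputs."""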

        if external_api_key != "":
            llm_api_key = external_api_key
        else:
            llm_api_key = get_api_key(f"{llm_provider.upper()}_API_KEY", llm_provider)

        print(f"LLM API key: {llm_api_key[:5]}...")
        if prime_directives is not None:
            system_message_str = prime_directives
        else:
            system_message = self.assistants.get(assistant, "")
            system_message_str = json.dumps(system_message)

        # Validate LLM model
        validate_models(llm_model, llm_provider, "LLM", base_ip, port, llm_api_key)

        # Validate embedding model
        validate_models(embedding_model, embedding_provider, "embedding", base_ip, port, llm_api_key)

        # Handle history
        if clear_history:
            self.messages = []
        elif history_steps > 0:
            self.messages = self.messages[-history_steps:]
        
        messages = self.messages

        # Handle stop
        if stop_string is None or stop_string == "None":
            stop_content = None
        else:
            stop_content = self.stop_strings.get(stop_string, None)
        stop = stop_content

        if llm_provider not in ["ollama", "llamacpp", "vllm", "lmstudio", "gemeni"]:
            if llm_provider == "kobold":
                stop = stop_content + \
                    ["\n\n\n\n\n"] if stop_content else ["\n\n\n\n\n"]
            elif llm_provider == "mistral":
                stop = stop_content + \
                    ["\n\n"] if stop_content else ["\n\n"]
            else:
                stop = stop_content if stop_content else None
        # Handle tools
        try:
            if tool and tool != "None":
                selected_tool = self.agent_tools.get(tool)
                if not selected_tool:
                    raise ValueError(f"Invalid agent tool selected: {tool}")

                # Prepare tool execution message
                tool_message = f"Execute the {tool} tool with the following input: {prompt}"
                system_prompt = json.dumps(selected_tool.system_prompt)

                # Send request to LLM for tool execution
                generated_text = await send_request(
                    llm_provider=llm_provider,
                    base_ip=base_ip,
                    port=port,
                    images=images,
                    llm_model=llm_model,
                    system_message=system_prompt,
                    user_message=tool_message,
                    messages=messages,
                    seed=seed,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    random=random,
                    top_k=top_k,
                    top_p=top_p,
                    repeat_penalty=repeat_penalty,
                    stop=stop,
                    keep_alive=keep_alive,
                    llm_api_key=llm_api_key,
                )
                # Parse the generated text for function calls
                function_call = None
                try:
                    response_data = json.loads(generated_text)
                    if 'function_call' in response_data:
                        function_call = response_data['function_call']
                        generated_text = response_data['content']
                except json.JSONDecodeError:
                    pass  # The response wasn't JSON, so it's just the generated text

                # Execute the tool with the LLM's response
                tool_args = {
                    "input": prompt,
                    "llm_response": generated_text,
                    "function_call": function_call,
                    "omni_input": tool_input,
                    "name": selected_tool.name,
                    "description": selected_tool.description,
                    "system_prompt": selected_tool.system_prompt
                }
                tool_result = selected_tool.execute(tool_args)

                # Update messages
                messages.append({"role": "user", "content": prompt})
                messages.append({
                    "role": "assistant",
                    "content": json.dumps(tool_result) if isinstance(tool_result, dict) else str(tool_result)
                })

                # Process the tool output
                if isinstance(tool_result, dict):
                    if "error" in tool_result:
                        generated_text = f"Error in {tool}: {tool_result['error']}"
                        tool_output = None
                    elif selected_tool.output_type and selected_tool.output_type in tool_result:
                        tool_output = tool_result[selected_tool.output_type]
                        generated_text = f"Agent {tool} executed successfully. Output generated."
                    else:
                        tool_output = tool_result
                        generated_text = str(tool_output)
                else:
                    tool_output = tool_result
                    generated_text = str(tool_output)

                # Return a 6-tuple matching RETURN_TYPES; tool execution retrieves no image,
                # so the placeholder image and mask are used.
                placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path)
                return (
                    prompt,
                    generated_text,
                    self.neg_prompts.get(neg_prompt, ""),
                    tool_output,
                    placeholder_image,
                    placeholder_mask,
                )
            else:
                response = await self.generate_response(
                    enable_RAG,
                    query_type,
                    prompt,
                    preset,
                    llm_provider,
                    base_ip,
                    port,
                    images,
                    llm_model,
                    system_message_str,
                    messages,
                    temperature,
                    max_tokens,
                    random,
                    top_k,
                    top_p,
                    repeat_penalty,
                    stop,
                    seed,
                    keep_alive,
                    llm_api_key,
                    task,
                    fill_mask,
                    output_mask_select,
                    precision,
                    attention
                )

                generated_text = response.get("Response")
                selected_neg_prompt_name = neg_prompt 
                omni = response.get("Tool_Output")
                retrieved_image = response.get("Retrieved_Image")  
                retrieved_mask = response.get("Mask")

                
                # Update messages
                messages.append({"role": "user", "content": prompt})
                messages.append({"role": "assistant", "content": generated_text})
                
                text_result = str(generated_text).strip()

                if mode:
                    embellish_content = self.embellish_prompts.get(embellish_prompt, "").strip()
                    style_content = self.style_prompts.get(style_prompt, "").strip()
           
                    lines = [line.strip() for line in text_result.split('\n') if line.strip()]
                    combined_prompts = []
                    
                    for line in lines:
                        if text_cleanup:
                            line = clean_text(line)
                        formatted_line = f"{embellish_content} {line} {style_content}".strip()
                        combined_prompts.append(formatted_line)
                    
                    combined_prompt = "\n".join(formatted_line for formatted_line in combined_prompts)
                    # Handle negative prompts
                    if selected_neg_prompt_name == "AI_Fill":
                        try:
                            neg_system_message = self.assistants.get("NegativePromptEngineer")
                            if not neg_system_message:
                                logger.error("NegativePromptEngineer not found in assistants configuration")
                                negative_prompt = "Error: NegativePromptEngineer not configured"
                            else:
                                user_message = f"Generate negative prompts for the following prompt:\n{text_result}"
                                
                                system_message_str = json.dumps(neg_system_message)
                                
                                logger.info(f"Requesting negative prompts for prompt: {text_result[:100]}...")
                                
                                neg_response = await send_request(
                                    llm_provider=llm_provider,
                                    base_ip=base_ip,
                                    port=port,
                                    images=None, 
                                    llm_model=llm_model,
                                    system_message=system_message_str,
                                    user_message=user_message,
                                    messages=[],  # Fresh context for negative generation
                                    seed=seed,
                                    temperature=temperature,
                                    max_tokens=max_tokens,
                                    random=random,
                                    top_k=top_k,
                                    top_p=top_p,
                                    repeat_penalty=repeat_penalty,
                                    stop=stop,
                                    keep_alive=keep_alive,
                                    llm_api_key=llm_api_key
                                )
                                
                                logger.debug(f"Received negative prompt response: {neg_response}")
                                
                                if neg_response:
                                    negative_lines = []
                                    for line in neg_response.split('\n'):
                                        line = line.strip()
                                        if line:
                                            negative_lines.append(line)
                                    
                                    while len(negative_lines) < len(lines):
                                        negative_lines.append(negative_lines[-1] if negative_lines else "")
                                    negative_lines = negative_lines[:len(lines)]
                                    
                                    negative_prompt = "\n".join(negative_lines)
                                else:
                                    negative_prompt = "Error: Empty response from LLM"
                        except Exception as e:
                            logger.error(f"Error generating negative prompts: {str(e)}", exc_info=True)
                            negative_prompt = f"Error generating negative prompts: {str(e)}"
                        
                    elif neg_prompt != "None":
                        neg_content = self.neg_prompts.get(neg_prompt, "").strip()
                        negative_lines = [neg_content for _ in range(len(lines))]
                        negative_prompt = "\n".join(negative_lines)
                    else:
                        negative_prompt = ""  

                else:
                    combined_prompt = text_result
                    negative_prompt = ""

                try:
                    if isinstance(retrieved_image, torch.Tensor):
                        # Ensure it's in the correct format (B, C, H, W)
                        if retrieved_image.dim() == 3:  # Single image (C, H, W)
                            image_tensor = retrieved_image.unsqueeze(0)  # Add batch dimension
                        else:
                            image_tensor = retrieved_image  # Already batched

                        # Create matching batch masks
                        batch_size = image_tensor.shape[0]
                        height = image_tensor.shape[2]
                        width = image_tensor.shape[3]
                        
                        # Create white masks (all ones) for each image in batch
                        mask_tensor = torch.ones((batch_size, 1, height, width), 
                                              dtype=torch.float32, 
                                              device=image_tensor.device)
                        
                        if retrieved_mask is not None:
                            # If we have masks, process them to match the batch
                            if isinstance(retrieved_mask, torch.Tensor):
                                if retrieved_mask.dim() == 3:  # Single mask
                                    mask_tensor = retrieved_mask.unsqueeze(0)
                                else:
                                    mask_tensor = retrieved_mask
                            else:
                                # Process retrieved_mask if it's not a tensor
                                mask_tensor = process_mask(retrieved_mask, image_tensor)
                    else:
                        image_tensor, default_mask_tensor = process_images_for_comfy(
                            retrieved_image, 
                            self.placeholder_image_path,
                            response_key=None,
                            field_name=None
                        )
                        mask_tensor = default_mask_tensor

                        if retrieved_mask is not None:
                            mask_tensor = process_mask(retrieved_mask, image_tensor)
                    return (
                        prompt,
                        combined_prompt,
                        negative_prompt,
                        omni,
                        image_tensor,
                        mask_tensor,
                    )

                except Exception as e:
                    logger.error(f"Exception in image processing: {str(e)}", exc_info=True)
                    placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path)
                    return (
                        prompt,
                        f"Error: {str(e)}",
                        "",
                        None,
                        placeholder_image,
                        placeholder_mask
                    )

        except Exception as e:
            logger.error(f"Exception occurred in process_chat: {str(e)}", exc_info=True)
            placeholder_image, placeholder_mask = load_placeholder_image(self.placeholder_image_path)
            return (
                prompt,
                f"Error: {str(e)}",
                "",
                None,
                placeholder_image,
                placeholder_mask
            )

    async def generate_response(
        self,
        enable_RAG,
        query_type,
        prompt,
        preset,
        llm_provider,
        base_ip,
        port,
        images,
        llm_model,
        system_message_str,
        messages,
        temperature,
        max_tokens,
        random,
        top_k,
        top_p,
        repeat_penalty,
        stop,
        seed,
        keep_alive,
        llm_api_key,
        task=None,
        fill_mask=False,
        output_mask_select="",
        precision="fp16",
        attention="sdpa",
    ):
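        """Dispatch the request to the colpali, graphrag, florence, or normal strategy.

        The graphrag and normal paths wrap the result in the node's response dict
        (Question/Response/Negative/Tool_Output/Retrieved_Image/Mask); the colpali and
        florence paths return whatever their respective apps produce.
        """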
        response_strategies = {
            "graphrag": self.graphrag_app.query,
            "colpali": self.colpali_app.query,
            "florence": self.florence_app.run_florence,
            "normal": lambda: send_request(
                llm_provider=llm_provider,
                base_ip=base_ip,
                port=port,
                images=images,
                llm_model=llm_model,
                system_message=system_message_str,
                user_message=prompt,
                messages=messages,
                seed=seed,
                temperature=temperature,
                max_tokens=max_tokens,
                random=random,
                top_k=top_k,
                top_p=top_p,
                repeat_penalty=repeat_penalty,
                stop=stop,
                keep_alive=keep_alive,
                llm_api_key=llm_api_key,
                tools=None,
                tool_choice=None,
                precision=precision,
                attention=attention
            ),
        }
        florence_tasks = list(self.florence_prompts.keys())
        if enable_RAG:
            if query_type == "colpali" or query_type == "colpali-v1.2" or query_type == "colqwen2":
                strategy = "colpali"
            else:  # For "global", "local", and "naive" query types
                strategy = "graphrag"
        elif task and task.lower() != 'none' and task in florence_tasks:
            strategy = "florence"
        else:
            strategy = "normal"

        print(f"Strategy: {strategy}")

        try:
            if strategy == "colpali":
                # Ensure the model is loaded before querying
                if self.cached_colpali_model is None:
                    self.cached_colpali_model = self.colpali_app.get_colpali_model(query_type)
                response = await response_strategies[strategy](prompt=prompt, query_type=query_type, system_message_str=system_message_str)
                return response
            elif strategy == "graphrag":
                response = await response_strategies[strategy](prompt=prompt, query_type=query_type, preset=preset) 
                return {
                        "Question": prompt,
                        "Response": response[0],
                        "Negative": "",
                        "Tool_Output": response[1],
                        "Retrieved_Image": None,
                        "Mask": None
                    }
            elif strategy == "florence":
                task_content = self.florence_prompts.get(task, "")
                response = await response_strategies[strategy](
                    images=images,
                    task=task,
                    task_prompt=task_content,
                    llm_model=llm_model,
                    precision=precision,
                    attention=attention,
                    fill_mask=fill_mask,
                    output_mask_select=output_mask_select,
                    keep_alive=keep_alive,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repeat_penalty,
                    seed=seed,
                    text_input=prompt,
                )
                print("Florence response:", response)
                return response
            else:
                response = await response_strategies[strategy]()
                print("Normal response:", response)
                return {
                    "Question": prompt,
                    "Response": response,
                    "Negative": "",
                    "Tool_Output": None,
                    "Retrieved_Image": None,
                    "Mask": None
                }

        except Exception as e:
            logger.error(f"Error processing strategy {strategy}: {str(e)}", exc_info=True)
            return {
                "Question": prompt,
                "Response": f"Error processing task: {str(e)}",
                "Negative": "",
                "Tool_Output": {"error": str(e)},
                "Retrieved_Image": None,
                "Mask": None
            }

    def process_chat_wrapper(self, *args, **kwargs):
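        """Synchronous wrapper used as the ComfyUI entry point; runs process_chat on an asyncio event loop."""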
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        logger.debug(f"process_chat_wrapper kwargs: {kwargs}")
        logger.debug(f"External LLM API Key: {kwargs.get('external_api_key', 'Not provided')}")
        return loop.run_until_complete(self.process_chat(*args, **kwargs))

    @classmethod
    def INPUT_TYPES(cls):
        node = cls()
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True, "default": "", "tooltip": "The main text input for the chat or query."}),
                "llm_provider": (["xai","llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "transformers"], {"default": node.llm_provider, "tooltip": "The provider of the language model to be used."}),
                "llm_model": ((), {"tooltip": "The specific language model to be used for processing."}),
                "base_ip": ("STRING", {"default": node.base_ip, "tooltip": "IP address of the LLM server."}),
                "port": ("STRING", {"default": node.port, "tooltip": "Port number for the LLM server connection."}),               
            },
            "optional": {
                "images": ("IMAGE", {"list": True, "tooltip": "Input image(s) for visual processing or context."}),
                "precision": (['fp16','bf16','fp32','int8','int4'],{"default": 'bf16', "tooltip": "Select preccision on Transformer models."}),
                "attention": (['flash_attention_2','sdpa','xformers', 'Shrek_COT_o1'],{"default": 'sdpa', "tooltip": "Select attention mechanism on Transformer models."}),
                "assistant": ([name for name in node.assistants.keys()], {"default": node.assistant, "tooltip": "The pre-defined assistant personality to use for responses."}),
                "tool": (["None"] + [name for name in node.agent_tools.keys()], {"default": "None", "tooltip": "Selects a specific tool or agent for task execution."}),
                "temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1, "tooltip": "Controls randomness in output generation. Higher values increase creativity but may reduce coherence."}),
                "max_tokens": ("INT", {"default": 2048, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Maximum number of tokens to generate in the response."}),
                "top_k": ("INT", {"default": 40, "min": 0, "max": 100, "tooltip": "Limits the next token selection to the K most likely tokens."}),
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.1, "tooltip": "Cumulative probability cutoff for token selection."}),
                "repeat_penalty": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Penalizes repetition in generated text."}),
                "stop_string": ([name for name in node.stop_strings.keys()], {"tooltip": "Specifies a string at which text generation should stop."}),
                "seed": ("INT", {"default": 94687328150, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Random seed for reproducible outputs."}),
                "random": ("BOOLEAN", {"default": False, "label_on": "Seed", "label_off": "Temperature", "tooltip": "Toggles between using a fixed seed or temperature-based randomness."}),
                "history_steps": ("INT", {"default": 10, "min": 0, "max": 0xffffffffffffffff, "tooltip": "Number of previous conversation turns to consider for context."}),
                "clear_history": ("BOOLEAN", {"default": False, "label_on": "Clear History", "label_off": "Keep History", "tooltip": "Option to clear or retain conversation history."}),
                "keep_alive": ("BOOLEAN", {"default": False, "label_on": "Keeps Model on Memory", "label_off": "Unloads Model from Memory", "tooltip": "Determines whether to keep the model loaded in memory between calls."}),
                "text_cleanup": ("BOOLEAN", {"default": True, "label_on": "Clean Response", "label_off": "Raw Text", "tooltip": "Applies text cleaning to the generated output."}),
                "mode": ("BOOLEAN", {"default": False, "label_on": "Using SD Mode", "label_off": "Using Chat Mode", "tooltip": "Switches between Stable Diffusion prompt generation and standard chat mode."}),
                "embellish_prompt": ([name for name in node.embellish_prompts.keys()], {"tooltip": "Adds pre-defined embellishments to the prompt."}),
                "style_prompt": ([name for name in node.style_prompts.keys()], {"tooltip": "Applies a pre-defined style to the prompt."}),
                "neg_prompt": ([name for name in node.neg_prompts.keys()], {"tooltip": "Adds a negative prompt to guide what should be avoided in generation."}),              
                "fill_mask": ("BOOLEAN", {"default": False, "label_on": "Fill Mask", "label_off": "No Fill", "tooltip": "Option to fill masks for Florence tasks."}),
                "output_mask_select": ("STRING", {"default": ""}),
                "task": ([name for name in node.florence_prompts.keys()], {"default": "None", "tooltip": "Select a Florence task."}),
                "embedding_provider": (["llamacpp", "ollama", "kobold", "lmstudio", "textgen", "groq", "gemini", "openai", "anthropic", "mistral", "sentence_transformers"], {"default": node.embedding_provider, "tooltip": "Provider for text embedding model."}),
                "embedding_model": ((), {"tooltip": "Specific embedding model to use."}),
                "tool_input": ("OMNI", {"default": None, "tooltip": "Additional input for the selected tool."}),
                "prime_directives": ("STRING", {"forceInput": True, "tooltip": "System message or prime directive for the AI assistant."}),
                "external_api_key":("STRING", {"default": "", "tooltip": "If this is not empty, it will be used instead of the API key from the .env file. Make sure it is empty to use the .env file."}),
                "top_k_search": ("INT", {"default": 3, "min": 1, "max": 10, "tooltip": "Find top scored image(s) from RAG."}),
                "aspect_ratio": (["1:1", "9:16", "16:9"], {"default": "16:9", "tooltip": "Select the aspect ratio for the image."}),
                "enable_RAG": ("BOOLEAN", {"default": False, "label_on": "RAG is Enabled", "label_off": "RAG is Disabled", "tooltip": "Enables Retrieval-Augmented Generation for enhanced context."}),
                "query_type": (["global", "local", "naive", "colpali", "colqwen2", "colpali-v1.2"], {"default": "global", "tooltip": "Selects the type of query strategy for RAG."}),
                "preset": (["Default", "Detailed", "Quick", "Bullet", "Comprehensive", "High-Level", "Focused"], {"default": "Default"}),
            },
            "hidden": {
                "model": ("STRING", {"default": ""}),
                "rag_root_dir": ("STRING", {"default": "rag_data"})
            }
        }

    @classmethod
    def IS_CHANGED(cls, **kwargs):
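        """Compare provider, connection, and query settings against a fresh node instance,
        refreshing cached model lists, and return True when a relevant setting changed."""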
        node = cls()

        llm_provider = kwargs.get('llm_provider', node.llm_provider)
        embedding_provider = kwargs.get('embedding_provider', node.embedding_provider)
        base_ip = kwargs.get('base_ip', node.base_ip)
        port = kwargs.get('port', node.port)
        query_type = kwargs.get('query_type', node.query_type)
        external_api_key = kwargs.get('external_api_key', '')
        task = kwargs.get('task', node.task)

        # Determine which API key to use
        def get_api_key_with_fallback(provider, external_api_key):
            if external_api_key and external_api_key != '':
                return external_api_key
            try:
                # print(f"Using {provider} API key from .env file")
                api_key = get_api_key(f"{provider.upper()}_API_KEY", provider)
                # print(f" {api_key} API key for {provider} found in .env file")
                return api_key

            except ValueError:
                return None

        api_key = get_api_key_with_fallback(llm_provider, external_api_key)

        # Check for changes
        llm_provider_changed = llm_provider != node.llm_provider
        embedding_provider_changed = embedding_provider != node.embedding_provider
        api_key_changed = external_api_key != node.external_api_key
        base_ip_changed = base_ip != node.base_ip
        port_changed = port != node.port
        query_type_changed = query_type != node.query_type
        task_changed = task != node.task

        # Always fetch new models if the provider, API key, base_ip, or port has changed
        if llm_provider_changed or api_key_changed or base_ip_changed or port_changed:
            try:
                new_llm_models = get_models(llm_provider, base_ip, port, api_key)
            except Exception as e:
                print(f"Error fetching LLM models: {e}")
                new_llm_models = []
            llm_model_changed = new_llm_models != node.llm_model
        else:
            llm_model_changed = False

        if embedding_provider_changed or api_key_changed or base_ip_changed or port_changed:
            try:
                new_embedding_models = get_models(embedding_provider, base_ip, port, api_key)
            except Exception as e:
                print(f"Error fetching embedding models: {e}")
                new_embedding_models = []
            embedding_model_changed = new_embedding_models != node.embedding_model
        else:
            embedding_model_changed = False

        if (llm_provider_changed or embedding_provider_changed or llm_model_changed or 
            embedding_model_changed or query_type_changed or task_changed or api_key_changed or
            base_ip_changed or port_changed):

            node.llm_provider = llm_provider
            node.embedding_provider = embedding_provider
            node.base_ip = base_ip
            node.port = port
            node.external_api_key = external_api_key
            node.query_type = query_type
            node.task = task

            if llm_model_changed:
                node.llm_model = new_llm_models
            if embedding_model_changed:
                node.embedding_model = new_embedding_models

            # Update other attributes
            for attr in ['seed', 'random', 'history_steps', 'clear_history', 'mode', 
                        'keep_alive', 'tool', 'enable_RAG', 'preset']:
                setattr(node, attr, kwargs.get(attr, getattr(node, attr)))

            return True

        return False

    RETURN_TYPES = ("STRING", "STRING", "STRING", "OMNI", "IMAGE", "MASK")
    RETURN_NAMES = ("Question", "Response", "Negative", "Tool_Output", "Retrieved_Image", "Mask")

    OUTPUT_TOOLTIPS = (
        "The original input question or prompt.",
        "The generated response from the language model.",
        "The negative prompt used (if applicable) for guiding image generation.",
        "Output from the selected tool, which can be code or any other data type.",
        "An image retrieved by the RAG system, if applicable.",
        "Mask image generated by Florence tasks."
    )
    FUNCTION = "process_chat_wrapper"
    OUTPUT_NODE = True
    CATEGORY = "ImpactFrames💥🎞️/IF_tools"
    DESCRIPTION = "ComfyUI node supporting API and local LLM providers with RAG capabilities. Processes text prompts, handles image inputs, and integrates with different language models and indexing strategies."