-
Notifications
You must be signed in to change notification settings - Fork 683
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🚀 Add VLM based Anomaly Model (#2344)
* [Draft] Llm on (#2165) * Add TaskType Explanation Signed-off-by: Bepitic <bepitic@gmail.com> * Add llm model Signed-off-by: Bepitic <bepitic@gmail.com> * add ollama Signed-off-by: Bepitic <bepitic@gmail.com> * better description for descr in title Signed-off-by: Bepitic <bepitic@gmail.com> * add text of llm into imageResult visualization * add text of llm into imageResult visualization Signed-off-by: Bepitic <bepitic@gmail.com> * latest changes Signed-off-by: Bepitic <bepitic@gmail.com> * add wip llava/llava_next Signed-off-by: Bepitic <bepitic@gmail.com> * add init Signed-off-by: Bepitic <bepitic@gmail.com> * add text of llm into imageResult visualization Signed-off-by: Bepitic <bepitic@gmail.com> * latest changes Signed-off-by: Bepitic <bepitic@gmail.com> * upd Lint Signed-off-by: Bepitic <bepitic@gmail.com> * fix visualization with description Signed-off-by: Bepitic <bepitic@gmail.com> * show the images every batch Signed-off-by: Bepitic <bepitic@gmail.com> * fix docstring and error management Signed-off-by: Bepitic <bepitic@gmail.com> * Add compatibility for TaskType.EXPLANATION. Signed-off-by: Bepitic <bepitic@gmail.com> * Remove, show in the engine-Visualization. * fix visualization and llm openai multishot. * fix Circular import problem * Add HugginFace To LLavaNext Signed-off-by: Bepitic <bepitic@gmail.com> --------- Signed-off-by: Bepitic <bepitic@gmail.com> * 🔨 Scaffold for refactor (#2340) * initial scafold Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Apply PR comments Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * rename dir Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add ChatGPT (#2341) * initial scafold Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Apply PR comments Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * rename dir Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * delete llm_ollama Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add ChatGPT Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add ChatGPT Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Remove LLM model Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add Huggingface (#2343) * initial scafold Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Apply PR comments Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * rename dir Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * delete llm_ollama Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add ChatGPT Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add ChatGPT Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Remove LLM model Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add transformers Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Remove llava Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * 🔨 Minor Refactor (#2345) Refactor Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * undo changes Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * undo changes Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * undo changes to image.py Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Add explanation visualizer (#2351) * Add explanation visualizer Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * bug-fix Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * 🔨 Allow setting API keys from env (#2353) Allow setting API keys from env Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * 🧪 Add tests (#2355) * Add tests Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * remove explanation task type Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * minor fixes Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Update changelog Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Fix tests Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Address PR comments Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * update name Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Update src/anomalib/models/image/vlm_ad/lightning_model.py Co-authored-by: Samet Akcay <samet.akcay@intel.com> * update name Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Bepitic <bepitic@gmail.com> Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> Co-authored-by: Paco <bepitic@gmail.com> Co-authored-by: Samet Akcay <samet.akcay@intel.com>
- Loading branch information
1 parent
6eeb7f6
commit 3a403ae
Showing
17 changed files
with
603 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Visual Anomaly Model.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .lightning_model import VlmAd | ||
|
||
__all__ = ["VlmAd"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""VLM backends.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .base import Backend | ||
from .chat_gpt import ChatGPT | ||
from .huggingface import Huggingface | ||
from .ollama import Ollama | ||
|
||
__all__ = ["Backend", "ChatGPT", "Huggingface", "Ollama"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
"""Base backend.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
|
||
from anomalib.models.image.vlm_ad.utils import Prompt | ||
|
||
|
||
class Backend(ABC): | ||
"""Base backend.""" | ||
|
||
@abstractmethod | ||
def __init__(self, model_name: str) -> None: | ||
"""Initialize the backend.""" | ||
|
||
@abstractmethod | ||
def add_reference_images(self, image: str | Path) -> None: | ||
"""Add reference images for k-shot.""" | ||
|
||
@abstractmethod | ||
def predict(self, image: str | Path, prompt: Prompt) -> str: | ||
"""Predict the anomaly label.""" | ||
|
||
@property | ||
@abstractmethod | ||
def num_reference_images(self) -> int: | ||
"""Get the number of reference images.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
"""ChatGPT backend.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import base64 | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
|
||
from dotenv import load_dotenv | ||
from lightning_utilities.core.imports import package_available | ||
|
||
from anomalib.models.image.vlm_ad.utils import Prompt | ||
|
||
from .base import Backend | ||
|
||
if package_available("openai"): | ||
from openai import OpenAI | ||
else: | ||
OpenAI = None | ||
|
||
if TYPE_CHECKING: | ||
from openai.types.chat import ChatCompletion | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ChatGPT(Backend): | ||
"""ChatGPT backend.""" | ||
|
||
def __init__(self, model_name: str, api_key: str | None = None) -> None: | ||
"""Initialize the ChatGPT backend.""" | ||
self._ref_images_encoded: list[str] = [] | ||
self.model_name: str = model_name | ||
self._client: OpenAI | None = None | ||
self.api_key = self._get_api_key(api_key) | ||
|
||
@property | ||
def client(self) -> OpenAI: | ||
"""Get the OpenAI client.""" | ||
if OpenAI is None: | ||
msg = "OpenAI is not installed. Please install it to use ChatGPT backend." | ||
raise ImportError(msg) | ||
if self._client is None: | ||
self._client = OpenAI(api_key=self.api_key) | ||
return self._client | ||
|
||
def add_reference_images(self, image: str | Path) -> None: | ||
"""Add reference images for k-shot.""" | ||
self._ref_images_encoded.append(self._encode_image_to_url(image)) | ||
|
||
@property | ||
def num_reference_images(self) -> int: | ||
"""Get the number of reference images.""" | ||
return len(self._ref_images_encoded) | ||
|
||
def predict(self, image: str | Path, prompt: Prompt) -> str: | ||
"""Predict the anomaly label.""" | ||
image_encoded = self._encode_image_to_url(image) | ||
messages = [] | ||
|
||
# few-shot | ||
if len(self._ref_images_encoded) > 0: | ||
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded)) | ||
|
||
messages.append(self._generate_message(content=prompt.predict, images=[image_encoded])) | ||
|
||
response: ChatCompletion = self.client.chat.completions.create(messages=messages, model=self.model_name) | ||
return response.choices[0].message.content | ||
|
||
@staticmethod | ||
def _generate_message(content: str, images: list[str] | None) -> dict: | ||
"""Generate a message.""" | ||
message: dict[str, list[dict] | str] = {"role": "user"} | ||
if images is not None: | ||
_content: list[dict[str, str | dict]] = [{"type": "text", "text": content}] | ||
_content.extend([{"type": "image_url", "image_url": {"url": image}} for image in images]) | ||
message["content"] = _content | ||
else: | ||
message["content"] = content | ||
return message | ||
|
||
def _encode_image_to_url(self, image: str | Path) -> str: | ||
"""Encode the image to base64 and embed in url string.""" | ||
image_path = Path(image) | ||
extension = image_path.suffix | ||
base64_encoded = self._encode_image_to_base_64(image_path) | ||
return f"data:image/{extension};base64,{base64_encoded}" | ||
|
||
@staticmethod | ||
def _encode_image_to_base_64(image: str | Path) -> str: | ||
"""Encode the image to base64.""" | ||
image = Path(image) | ||
return base64.b64encode(image.read_bytes()).decode("utf-8") | ||
|
||
def _get_api_key(self, api_key: str | None = None) -> str: | ||
if api_key is None: | ||
load_dotenv() | ||
api_key = os.getenv("OPENAI_API_KEY") | ||
if api_key is None: | ||
msg = ( | ||
f"OpenAI API key must be provided to use {self.model_name}." | ||
" Please provide the API key in the constructor, or set the OPENAI_API_KEY environment variable" | ||
" or in a `.env` file." | ||
) | ||
raise ValueError(msg) | ||
return api_key |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Huggingface backend.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
from lightning_utilities.core.imports import package_available | ||
from PIL import Image | ||
from transformers.modeling_utils import PreTrainedModel | ||
|
||
from anomalib.models.image.vlm_ad.utils import Prompt | ||
|
||
from .base import Backend | ||
|
||
if package_available("transformers"): | ||
import transformers | ||
from transformers.modeling_utils import PreTrainedModel | ||
from transformers.processing_utils import ProcessorMixin | ||
else: | ||
transformers = None | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Huggingface(Backend): | ||
"""Huggingface backend.""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
) -> None: | ||
"""Initialize the Huggingface backend.""" | ||
self.model_name: str = model_name | ||
self._ref_images: list[str] = [] | ||
self._processor: ProcessorMixin | None = None | ||
self._model: PreTrainedModel | None = None | ||
|
||
@property | ||
def processor(self) -> ProcessorMixin: | ||
"""Get the Huggingface processor.""" | ||
if self._processor is None: | ||
if transformers is None: | ||
msg = "transformers is not installed." | ||
raise ValueError(msg) | ||
self._processor = transformers.LlavaNextProcessor.from_pretrained(self.model_name) | ||
return self._processor | ||
|
||
@property | ||
def model(self) -> PreTrainedModel: | ||
"""Get the Huggingface model.""" | ||
if self._model is None: | ||
if transformers is None: | ||
msg = "transformers is not installed." | ||
raise ValueError(msg) | ||
self._model = transformers.LlavaNextForConditionalGeneration.from_pretrained(self.model_name) | ||
return self._model | ||
|
||
@staticmethod | ||
def _generate_message(content: str, images: list[str] | None) -> dict: | ||
"""Generate a message.""" | ||
message: dict[str, str | list[dict]] = {"role": "user"} | ||
_content: list[dict[str, str]] = [{"type": "text", "text": content}] | ||
if images is not None: | ||
_content.extend([{"type": "image"} for _ in images]) | ||
message["content"] = _content | ||
return message | ||
|
||
def add_reference_images(self, image: str | Path) -> None: | ||
"""Add reference images for k-shot.""" | ||
self._ref_images.append(Image.open(image)) | ||
|
||
@property | ||
def num_reference_images(self) -> int: | ||
"""Get the number of reference images.""" | ||
return len(self._ref_images) | ||
|
||
def predict(self, image_path: str | Path, prompt: Prompt) -> str: | ||
"""Predict the anomaly label.""" | ||
image = Image.open(image_path) | ||
messages: list[dict] = [] | ||
|
||
if len(self._ref_images) > 0: | ||
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images)) | ||
|
||
messages.append(self._generate_message(content=prompt.predict, images=[image])) | ||
processed_prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)] | ||
|
||
images = [*self._ref_images, image] | ||
inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device) | ||
outputs = self.model.generate(**inputs, max_new_tokens=100) | ||
result = self.processor.decode(outputs[0], skip_special_tokens=True) | ||
print(result) | ||
return result |
Oops, something went wrong.