-
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5cc7f29
commit f38dfb8
Showing
11 changed files
with
370 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,115 @@ | ||
from homeassistant import config_entries | ||
from .const import DOMAIN, CONF_API_KEY | ||
from homeassistant.helpers.selector import selector | ||
from homeassistant.exceptions import ServiceValidationError | ||
from .const import DOMAIN, CONF_API_KEY, CONF_MODE, CONF_IP_ADDRESS, CONF_PORT | ||
import voluptuous as vol | ||
import logging | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
async def validate_mode(user_input: dict):
    """Ensure a provider mode was selected in the first config-flow step.

    Args:
        user_input: the data submitted for the "user" step.

    Raises:
        ServiceValidationError: with key "empty_mode" if CONF_MODE is empty.
    """
    # Lazy %-args: the value is only formatted when debug logging is enabled.
    _LOGGER.debug("Validating mode: %s", user_input[CONF_MODE])
    if not user_input[CONF_MODE]:
        raise ServiceValidationError("empty_mode")
async def validate_localai(user_input: dict):
    """Ensure the LocalAI connection settings are filled in.

    Args:
        user_input: the data submitted for the "localai" step.

    Raises:
        ServiceValidationError: "empty_ip_address" or "empty_port" when the
            corresponding field is missing or falsy.
    """
    # Lazy %-args: values are only formatted when debug logging is enabled.
    _LOGGER.debug("Validating IP Address: %s", user_input[CONF_IP_ADDRESS])
    if not user_input[CONF_IP_ADDRESS]:
        raise ServiceValidationError("empty_ip_address")

    _LOGGER.debug("Validating Port: %s", user_input[CONF_PORT])
    if not user_input[CONF_PORT]:
        raise ServiceValidationError("empty_port")
async def validate_openai(user_input: dict):
    """Ensure an OpenAI API key was provided.

    Args:
        user_input: the data submitted for the "openai" step.

    Raises:
        ServiceValidationError: with key "empty_api_key" if CONF_API_KEY is empty.
    """
    # SECURITY: never log the API key itself — only whether one was supplied.
    _LOGGER.debug("Validating API key: provided=%s", bool(user_input[CONF_API_KEY]))
    if not user_input[CONF_API_KEY]:
        raise ServiceValidationError("empty_api_key")
class gpt4visionConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Config flow for GPT4Vision.

    Step 1 ("user") selects the provider (OpenAI or LocalAI); step 2
    collects that provider's settings and creates the config entry.
    """

    VERSION = 1

    async def async_step_user(self, user_input=None):
        """First step: let the user choose between OpenAI and LocalAI."""
        # The API key is collected in the provider-specific step, not here.
        data_schema = vol.Schema({
            vol.Required(CONF_MODE, default="OpenAI"): selector({
                "select": {
                    "options": ["OpenAI", "LocalAI"],
                    "mode": "dropdown",
                    "sort": True,
                    "custom_value": False
                }
            }),
        })

        if user_input is not None:
            # Remember the chosen mode so the provider step can store it
            # alongside its own settings.
            self.init_info = user_input
            if user_input[CONF_MODE] == "LocalAI":
                _LOGGER.debug("LocalAI selected")
                return await self.async_step_localai()
            _LOGGER.debug("OpenAI selected")
            return await self.async_step_openai()

        return self.async_show_form(
            step_id="user",
            data_schema=data_schema,
        )

    async def async_step_localai(self, user_input=None):
        """Collect and validate LocalAI host/port, then create the entry."""
        data_schema = vol.Schema({
            vol.Required(CONF_IP_ADDRESS): str,
            vol.Required(CONF_PORT, default=8080): int,
        })

        if user_input is not None:
            try:
                await validate_localai(user_input)
                # Carry the mode chosen in the first step into the entry data.
                user_input[CONF_MODE] = self.init_info[CONF_MODE]
                return self.async_create_entry(title="GPT4Vision LocalAI", data=user_input)
            except ServiceValidationError as e:
                # HA expects the errors dict to map to translation-key strings,
                # not exception objects.
                return self.async_show_form(
                    step_id="localai",
                    data_schema=data_schema,
                    errors={"base": str(e)}
                )

        return self.async_show_form(
            step_id="localai",
            data_schema=data_schema
        )

    async def async_step_openai(self, user_input=None):
        """Collect and validate the OpenAI API key, then create the entry."""
        data_schema = vol.Schema({
            vol.Required(CONF_API_KEY): str,
        })

        if user_input is not None:
            try:
                await validate_openai(user_input)
                # Carry the mode chosen in the first step into the entry data.
                user_input[CONF_MODE] = self.init_info[CONF_MODE]
                return self.async_create_entry(title="GPT4Vision OpenAI", data=user_input)
            except ServiceValidationError as e:
                # HA expects the errors dict to map to translation-key strings,
                # not exception objects.
                return self.async_show_form(
                    step_id="openai",
                    data_schema=data_schema,
                    errors={"base": str(e)}
                )

        return self.async_show_form(
            step_id="openai",
            data_schema=data_schema
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Send a request to localai API '/v1/chat/completions' endpoint""" | ||
|
||
|
||
import requests | ||
import json | ||
import base64 | ||
|
||
|
||
def localai_analyzer(image_path, message, model, ip_address="localhost", port=8080):
    """Send an image plus a text prompt to a LocalAI '/v1/chat/completions' endpoint.

    Args:
        image_path (str): path where the image is stored, e.g. "/config/www/tmp/image.jpg"
        message (str): prompt to be sent to the AI model alongside the image
        model (str): model name, e.g. "gpt-4-vision-preview"
        ip_address (str): host of the LocalAI server (default "localhost")
        port (int): port of the LocalAI server (default 8080)

    Returns:
        dict: parsed JSON response from the API

    Raises:
        Exception: if the HTTP response status is not 200
    """
    # Base64-encode the image so it can be embedded in the JSON payload.
    with open(image_path, "rb") as image_file:
        image_base64 = base64.b64encode(image_file.read()).decode("utf-8")

    data = {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": message},
                {"type": "image_url",
                 "image_url": {"url": "data:image/jpeg;base64," + image_base64}},
            ],
        }],
    }

    # Endpoint is built from the configured host/port instead of being
    # hardcoded to localhost:8080.
    response = requests.post(
        f"http://{ip_address}:{port}/v1/chat/completions", json=data)

    if response.status_code != 200:
        raise Exception(
            f"Request failed with status code {response.status_code}")

    # requests decodes the JSON body directly; no manual json.loads needed.
    return response.json()


if __name__ == "__main__":
    # Manual smoke test only — must not execute on import.
    print(localai_analyzer("C:/Users/valen/Pictures/Screenshots/test.png",
                           "What is in this image?", "gpt-4-vision-preview"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.