Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAI Files & Vector Store Hooks #39248

Merged
merged 9 commits into from
May 1, 2024
178 changes: 153 additions & 25 deletions airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,22 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, BinaryIO, Literal

from openai import OpenAI

if TYPE_CHECKING:
from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted
from openai.types import FileDeleted, FileObject
from openai.types.beta import (
Assistant,
AssistantDeleted,
Thread,
ThreadDeleted,
VectorStore,
VectorStoreDeleted,
)
from openai.types.beta.threads import Message, Run
from openai.types.beta.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
Expand Down Expand Up @@ -111,7 +120,8 @@ def create_chat_completion(
return response.choices

def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant:
"""Create an OpenAI assistant using the given model.
"""
Create an OpenAI assistant using the given model.

:param model: The OpenAI model for the assistant to use.
"""
Expand All @@ -132,27 +142,18 @@ def get_assistants(self, **kwargs: Any) -> list[Assistant]:
assistants = self.conn.beta.assistants.list(**kwargs)
return assistants.data

def get_assistant_by_name(self, assistant_name: str) -> Assistant | None:
"""Get an OpenAI Assistant object for a given name.

:param assistant_name: The name of the assistant to retrieve
"""
response = self.get_assistants()
for assistant in response:
if assistant.name == assistant_name:
return assistant
return None

def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
"""Modify an existing Assistant object.
"""
Modify an existing Assistant object.

:param assistant_id: The ID of the assistant to be modified.
"""
assistant = self.conn.beta.assistants.update(assistant_id=assistant_id, **kwargs)
return assistant

def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
"""Delete an OpenAI Assistant for a given ID.
"""
Delete an OpenAI Assistant for a given ID.

:param assistant_id: The ID of the assistant to delete.
"""
Expand All @@ -165,16 +166,18 @@ def create_thread(self, **kwargs: Any) -> Thread:
return thread

def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread:
"""Modify an existing Thread object.
"""
Modify an existing Thread object.

:param thread_id: The ID of the thread to modify.
:param thread_id: The ID of the thread to modify. Only the metadata can be modified.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
"""
thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata)
return thread

def delete_thread(self, thread_id: str) -> ThreadDeleted:
"""Delete an OpenAI thread for a given thread_id.
"""
Delete an OpenAI thread for a given thread_id.

:param thread_id: The ID of the thread to delete.
"""
Expand All @@ -184,7 +187,8 @@ def delete_thread(self, thread_id: str) -> ThreadDeleted:
def create_message(
self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any
) -> Message:
"""Create a message for a given Thread.
"""
Create a message for a given Thread.

:param thread_id: The ID of the thread to create a message for.
:param role: The role of the entity that is creating the message. Allowed values include: 'user', 'assistant'.
Expand All @@ -196,15 +200,17 @@ def create_message(
return thread_message

def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]:
"""Return a list of messages for a given Thread.
"""
Return a list of messages for a given Thread.

:param thread_id: The ID of the thread the messages belong to.
"""
messages = self.conn.beta.threads.messages.list(thread_id=thread_id, **kwargs)
return messages.data

def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
"""Modify an existing message for a given Thread.
"""
Modify an existing message for a given Thread.

:param thread_id: The ID of the thread to which this message belongs.
:param message_id: The ID of the message to modify.
Expand All @@ -215,16 +221,31 @@ def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
return thread_message

def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""Create a run for a given thread and assistant.
"""
Create a run for a given thread and assistant.

:param thread_id: The ID of the thread to run.
:param assistant_id: The ID of the assistant to use to execute this run.
"""
run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs)
return run

def create_run_and_poll(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""
Create a run for a given thread and assistant and then polls until completion.

:param thread_id: The ID of the thread to run.
:param assistant_id: The ID of the assistant to use to execute this run.
:return: An OpenAI Run object
"""
run = self.conn.beta.threads.runs.create_and_poll(
thread_id=thread_id, assistant_id=assistant_id, **kwargs
)
return run

def get_run(self, thread_id: str, run_id: str) -> Run:
"""Retrieve a run for a given thread and run.
"""
Retrieve a run for a given thread and run.

:param thread_id: The ID of the thread that was run.
:param run_id: The ID of the run to retrieve.
Expand Down Expand Up @@ -257,11 +278,118 @@ def create_embeddings(
model: str = "text-embedding-ada-002",
**kwargs: Any,
) -> list[float]:
"""Generate embeddings for the given text using the given model.
"""
Generate embeddings for the given text using the given model.

:param text: The text to generate embeddings for.
:param model: The model to use for generating embeddings.
"""
response = self.conn.embeddings.create(model=model, input=text, **kwargs)
embeddings: list[float] = response.data[0].embedding
return embeddings

def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants"]) -> FileObject:
"""
Upload a file that can be used across various endpoints. The size of all the files uploaded by one organization can be up to 100 GB.

:param file: The File object (not file name) to be uploaded.
:param purpose: The intended purpose of the uploaded file. Use "fine-tune" for
Fine-tuning and "assistants" for Assistants and Messages.
"""
with open(file, "rb") as file_stream:
file_object = self.conn.files.create(file=file_stream, purpose=purpose)
return file_object

def get_file(self, file_id: str) -> FileObject:
"""
Return information about a specific file.

:param file_id: The ID of the file to use for this request.
"""
file = self.conn.files.retrieve(file_id=file_id)
return file

def get_files(self) -> list[FileObject]:
"""Return a list of files that belong to the user's organization."""
files = self.conn.files.list()
return files.data

def delete_file(self, file_id: str) -> FileDeleted:
"""
Delete a file.

:param file_id: The ID of the file to be deleted.
"""
response = self.conn.files.delete(file_id=file_id)
return response

def create_vector_store(self, **kwargs: Any) -> VectorStore:
"""Create a vector store."""
vector_store = self.conn.beta.vector_stores.create(**kwargs)
return vector_store

def get_vector_stores(self, **kwargs: Any) -> list[VectorStore]:
"""Return a list of vector stores."""
vector_stores = self.conn.beta.vector_stores.list(**kwargs)
return vector_stores.data

def get_vector_store(self, vector_store_id: str) -> VectorStore:
"""
Retrieve a vector store.

:param vector_store_id: The ID of the vector store to retrieve.
"""
vector_store = self.conn.beta.vector_stores.retrieve(vector_store_id=vector_store_id)
return vector_store

def modify_vector_store(self, vector_store_id: str, **kwargs: Any) -> VectorStore:
"""
Modify a vector store.

:param vector_store_id: The ID of the vector store to modify.
"""
vector_store = self.conn.beta.vector_stores.update(vector_store_id=vector_store_id, **kwargs)
return vector_store

def delete_vector_store(self, vector_store_id: str) -> VectorStoreDeleted:
"""
Delete a vector store.

:param vector_store_id: The ID of the vector store to delete.
"""
response = self.conn.beta.vector_stores.delete(vector_store_id=vector_store_id)
return response

def upload_files_to_vector_store(
self, vector_store_id: str, files: list[BinaryIO]
) -> VectorStoreFileBatch:
"""
Upload files to a vector store and poll until completion.

:param vector_store_id: The ID of the vector store the files are to be uploaded
to.
:param files: A list of binary files to upload.
"""
file_batch = self.conn.beta.vector_stores.file_batches.upload_and_poll(
vector_store_id=vector_store_id, files=files
)
return file_batch

def get_vector_store_files(self, vector_store_id: str) -> list[VectorStoreFile]:
"""
Return a list of vector store files.

:param vector_store_id:
"""
vector_store_files = self.conn.beta.vector_stores.files.list(vector_store_id=vector_store_id)
return vector_store_files.data

def delete_vector_store_file(self, vector_store_id: str, file_id: str) -> VectorStoreFileDeleted:
"""
Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use delete_file.

:param vector_store_id: The ID of the vector store that the file belongs to.
:param file_id: The ID of the file to delete.
"""
response = self.conn.beta.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=file_id)
return response
2 changes: 1 addition & 1 deletion airflow/providers/openai/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ integrations:

dependencies:
- apache-airflow>=2.7.0
- openai[datalib]>=1.16
- openai[datalib]>=1.23

hooks:
- integration-name: OpenAI
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@
"openai": {
"deps": [
"apache-airflow>=2.7.0",
"openai[datalib]>=1.16"
"openai[datalib]>=1.23"
],
"devel-deps": [],
"cross-providers-deps": [],
Expand Down
Loading