Skip to content

Commit

Permalink
OpenAI Files & Vector Store Hooks (#39248)
Browse files Browse the repository at this point in the history
* bumping openai version
* adding hooks for files and vector stores
* removing functions to get objects by names
  • Loading branch information
nathadfield authored May 1, 2024
1 parent 42dbcca commit da6c2bc
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 35 deletions.
178 changes: 153 additions & 25 deletions airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,22 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, BinaryIO, Literal

from openai import OpenAI

if TYPE_CHECKING:
from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted
from openai.types import FileDeleted, FileObject
from openai.types.beta import (
Assistant,
AssistantDeleted,
Thread,
ThreadDeleted,
VectorStore,
VectorStoreDeleted,
)
from openai.types.beta.threads import Message, Run
from openai.types.beta.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
Expand Down Expand Up @@ -111,7 +120,8 @@ def create_chat_completion(
return response.choices

def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant:
"""Create an OpenAI assistant using the given model.
"""
Create an OpenAI assistant using the given model.
:param model: The OpenAI model for the assistant to use.
"""
Expand All @@ -132,27 +142,18 @@ def get_assistants(self, **kwargs: Any) -> list[Assistant]:
assistants = self.conn.beta.assistants.list(**kwargs)
return assistants.data

def get_assistant_by_name(self, assistant_name: str) -> Assistant | None:
"""Get an OpenAI Assistant object for a given name.
:param assistant_name: The name of the assistant to retrieve
"""
response = self.get_assistants()
for assistant in response:
if assistant.name == assistant_name:
return assistant
return None

def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
"""Modify an existing Assistant object.
"""
Modify an existing Assistant object.
:param assistant_id: The ID of the assistant to be modified.
"""
assistant = self.conn.beta.assistants.update(assistant_id=assistant_id, **kwargs)
return assistant

def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
"""Delete an OpenAI Assistant for a given ID.
"""
Delete an OpenAI Assistant for a given ID.
:param assistant_id: The ID of the assistant to delete.
"""
Expand All @@ -165,16 +166,18 @@ def create_thread(self, **kwargs: Any) -> Thread:
return thread

def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread:
"""Modify an existing Thread object.
"""
Modify an existing Thread object.
:param thread_id: The ID of the thread to modify.
:param thread_id: The ID of the thread to modify. Only the metadata can be modified.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
"""
thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata)
return thread

def delete_thread(self, thread_id: str) -> ThreadDeleted:
"""Delete an OpenAI thread for a given thread_id.
"""
Delete an OpenAI thread for a given thread_id.
:param thread_id: The ID of the thread to delete.
"""
Expand All @@ -184,7 +187,8 @@ def delete_thread(self, thread_id: str) -> ThreadDeleted:
def create_message(
self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any
) -> Message:
"""Create a message for a given Thread.
"""
Create a message for a given Thread.
:param thread_id: The ID of the thread to create a message for.
:param role: The role of the entity that is creating the message. Allowed values include: 'user', 'assistant'.
Expand All @@ -196,15 +200,17 @@ def create_message(
return thread_message

def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]:
"""Return a list of messages for a given Thread.
"""
Return a list of messages for a given Thread.
:param thread_id: The ID of the thread the messages belong to.
"""
messages = self.conn.beta.threads.messages.list(thread_id=thread_id, **kwargs)
return messages.data

def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
"""Modify an existing message for a given Thread.
"""
Modify an existing message for a given Thread.
:param thread_id: The ID of the thread to which this message belongs.
:param message_id: The ID of the message to modify.
Expand All @@ -215,16 +221,31 @@ def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
return thread_message

def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""Create a run for a given thread and assistant.
"""
Create a run for a given thread and assistant.
:param thread_id: The ID of the thread to run.
:param assistant_id: The ID of the assistant to use to execute this run.
"""
run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs)
return run

def create_run_and_poll(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""
Create a run for a given thread and assistant and then polls until completion.
:param thread_id: The ID of the thread to run.
:param assistant_id: The ID of the assistant to use to execute this run.
:return: An OpenAI Run object
"""
run = self.conn.beta.threads.runs.create_and_poll(
thread_id=thread_id, assistant_id=assistant_id, **kwargs
)
return run

def get_run(self, thread_id: str, run_id: str) -> Run:
"""Retrieve a run for a given thread and run.
"""
Retrieve a run for a given thread and run.
:param thread_id: The ID of the thread that was run.
:param run_id: The ID of the run to retrieve.
Expand Down Expand Up @@ -257,11 +278,118 @@ def create_embeddings(
model: str = "text-embedding-ada-002",
**kwargs: Any,
) -> list[float]:
"""Generate embeddings for the given text using the given model.
"""
Generate embeddings for the given text using the given model.
:param text: The text to generate embeddings for.
:param model: The model to use for generating embeddings.
"""
response = self.conn.embeddings.create(model=model, input=text, **kwargs)
embeddings: list[float] = response.data[0].embedding
return embeddings

def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants"]) -> FileObject:
"""
Upload a file that can be used across various endpoints. The size of all the files uploaded by one organization can be up to 100 GB.
:param file: The File object (not file name) to be uploaded.
:param purpose: The intended purpose of the uploaded file. Use "fine-tune" for
Fine-tuning and "assistants" for Assistants and Messages.
"""
with open(file, "rb") as file_stream:
file_object = self.conn.files.create(file=file_stream, purpose=purpose)
return file_object

def get_file(self, file_id: str) -> FileObject:
"""
Return information about a specific file.
:param file_id: The ID of the file to use for this request.
"""
file = self.conn.files.retrieve(file_id=file_id)
return file

def get_files(self) -> list[FileObject]:
"""Return a list of files that belong to the user's organization."""
files = self.conn.files.list()
return files.data

def delete_file(self, file_id: str) -> FileDeleted:
"""
Delete a file.
:param file_id: The ID of the file to be deleted.
"""
response = self.conn.files.delete(file_id=file_id)
return response

def create_vector_store(self, **kwargs: Any) -> VectorStore:
"""Create a vector store."""
vector_store = self.conn.beta.vector_stores.create(**kwargs)
return vector_store

def get_vector_stores(self, **kwargs: Any) -> list[VectorStore]:
"""Return a list of vector stores."""
vector_stores = self.conn.beta.vector_stores.list(**kwargs)
return vector_stores.data

def get_vector_store(self, vector_store_id: str) -> VectorStore:
"""
Retrieve a vector store.
:param vector_store_id: The ID of the vector store to retrieve.
"""
vector_store = self.conn.beta.vector_stores.retrieve(vector_store_id=vector_store_id)
return vector_store

def modify_vector_store(self, vector_store_id: str, **kwargs: Any) -> VectorStore:
"""
Modify a vector store.
:param vector_store_id: The ID of the vector store to modify.
"""
vector_store = self.conn.beta.vector_stores.update(vector_store_id=vector_store_id, **kwargs)
return vector_store

def delete_vector_store(self, vector_store_id: str) -> VectorStoreDeleted:
"""
Delete a vector store.
:param vector_store_id: The ID of the vector store to delete.
"""
response = self.conn.beta.vector_stores.delete(vector_store_id=vector_store_id)
return response

def upload_files_to_vector_store(
self, vector_store_id: str, files: list[BinaryIO]
) -> VectorStoreFileBatch:
"""
Upload files to a vector store and poll until completion.
:param vector_store_id: The ID of the vector store the files are to be uploaded
to.
:param files: A list of binary files to upload.
"""
file_batch = self.conn.beta.vector_stores.file_batches.upload_and_poll(
vector_store_id=vector_store_id, files=files
)
return file_batch

def get_vector_store_files(self, vector_store_id: str) -> list[VectorStoreFile]:
"""
Return a list of vector store files.
:param vector_store_id:
"""
vector_store_files = self.conn.beta.vector_stores.files.list(vector_store_id=vector_store_id)
return vector_store_files.data

def delete_vector_store_file(self, vector_store_id: str, file_id: str) -> VectorStoreFileDeleted:
"""
Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use delete_file.
:param vector_store_id: The ID of the vector store that the file belongs to.
:param file_id: The ID of the file to delete.
"""
response = self.conn.beta.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=file_id)
return response
2 changes: 1 addition & 1 deletion airflow/providers/openai/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ integrations:

dependencies:
- apache-airflow>=2.7.0
- openai[datalib]>=1.16
- openai[datalib]>=1.23

hooks:
- integration-name: OpenAI
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@
"openai": {
"deps": [
"apache-airflow>=2.7.0",
"openai[datalib]>=1.16"
"openai[datalib]>=1.23"
],
"devel-deps": [],
"cross-providers-deps": [],
Expand Down
Loading

0 comments on commit da6c2bc

Please sign in to comment.