Skip to content

Commit

Permalink
Decouple tasks from model engines and introduce modalities (jupyterla…
Browse files Browse the repository at this point in the history
…b#15)

* fix example playground config

* re-enable check-release workflow for PRs

* decouple tasks from models
  • Loading branch information
dlqqq authored Mar 15, 2023
1 parent 1b1151c commit c153c9a
Show file tree
Hide file tree
Showing 22 changed files with 222 additions and 188 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/check-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ name: Check Release
on:
push:
branches: ["*"]
# pull_request:
# branches: ["*"]
pull_request:
branches: ["*"]
release:
types: [published]
schedule:
Expand Down
5 changes: 2 additions & 3 deletions packages/jupyter-ai-chatgpt/jupyter_ai_chatgpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from ._version import __version__

# expose ChatGptModelEngine on the root module so that it may be declared as an
# entrypoint in `pyproject.toml`
# expose engines and tasks on the module root so that they may be declared as
# entrypoints in `pyproject.toml`
from .engine import ChatGptModelEngine


def _jupyter_labextension_paths():
return [{
"src": "labextension",
Expand Down
47 changes: 7 additions & 40 deletions packages/jupyter-ai-chatgpt/jupyter_ai_chatgpt/engine.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,24 @@
import openai
from traitlets.config import Unicode

from typing import List, Dict
from typing import Dict

from jupyter_ai.engine import BaseModelEngine, DefaultTaskDefinition
from jupyter_ai.engine import BaseModelEngine
from jupyter_ai.models import DescribeTaskResponse

class ChatGptModelEngine(BaseModelEngine):
name = "chatgpt"
input_type = "txt"
output_type = "txt"
id = "chatgpt"
name = "ChatGPT"
modalities = [
"txt2txt"
]

api_key = Unicode(
config=True,
help="OpenAI API key",
allow_none=False
)

def list_default_tasks(self) -> List[DefaultTaskDefinition]:
# Tasks your model engine provides by default.
return [
{
"id": "explain-code",
"name": "Explain code",
"prompt_template": "Explain the following Python 3 code. The first sentence must begin with the phrase \"The code below\".\n{body}",
"insertion_mode": "above"
},
{
"id": "generate-code",
"name": "Generate code",
"prompt_template": "Generate Python 3 code in Markdown according to the following definition.\n{body}",
"insertion_mode": "below"
},
{
"id": "explain-code-in-cells-above",
"name": "Explain code in cells above",
"prompt_template": "Explain the following Python 3 code. The first sentence must begin with the phrase \"The code below\".\n{body}",
"insertion_mode": "above-in-cells"
},
{
"id": "generate-code-in-cells-below",
"name": "Generate code in cells below",
"prompt_template": "Generate Python 3 code in Markdown according to the following definition.\n{body}",
"insertion_mode": "below-in-cells"
},
{
"id": "freeform",
"name": "Freeform prompt",
"prompt_template": "{body}",
"insertion_mode": "below"
}
]

async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
if "body" not in prompt_variables:
raise Exception("Prompt body must be specified.")
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai-chatgpt/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ test = [
"pytest-cov"
]

[project.entry-points."jupyter_ai.model_engine_class"]
test = "jupyter_ai_chatgpt:ChatGptModelEngine"
[project.entry-points."jupyter_ai.model_engine_classes"]
ChatGptModelEngine = "jupyter_ai_chatgpt:ChatGptModelEngine"

[tool.hatch.version]
source = "nodejs"
Expand Down
8 changes: 4 additions & 4 deletions packages/jupyter-ai-dalle/jupyter_ai_dalle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from ._version import __version__

# expose DalleModelEngine on the root module so that it may be declared as an
# entrypoint in `pyproject.toml`
# expose engines and tasks on the module root so that they may be declared as
# entrypoints in `pyproject.toml`
from .engine import DalleModelEngine

from .tasks import tasks

def _jupyter_labextension_paths():
return [{
"src": "labextension",
"dest": "jupyter_ai_dalle"
"dest": "@jupyter-ai/dalle"
}]
34 changes: 7 additions & 27 deletions packages/jupyter-ai-dalle/jupyter_ai_dalle/engine.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,23 @@
from typing import List, Dict
from typing import Dict
from traitlets.config import Unicode
import openai

from jupyter_ai.engine import BaseModelEngine, DefaultTaskDefinition
from jupyter_ai.engine import BaseModelEngine
from jupyter_ai.models import DescribeTaskResponse

class DalleModelEngine(BaseModelEngine):
name = "dalle"
input_type = "txt"
output_type = "img"
id = "dalle"
name = "DALL-E"
modalities = [
"txt2img"
]

api_key = Unicode(
config=True,
help="OpenAI API key",
allow_none=False
)

def list_default_tasks(self) -> List[DefaultTaskDefinition]:
return [
{
"id": "generate-image",
"name": "Generate image below",
"prompt_template": "{body}",
"insertion_mode": "below-in-image"
},
{
"id": "generate-photorealistic-image",
"name": "Generate photorealistic image below",
"prompt_template": "{body} in a photorealistic style",
"insertion_mode": "below-in-image"
},
{
"id": "generate-cartoon-image",
"name": "Generate cartoon image below",
"prompt_template": "{body} in the style of a cartoon",
"insertion_mode": "below-in-image"
}
]

async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]) -> str:
if "body" not in prompt_variables:
raise Exception("Prompt body must be specified.")
Expand Down
7 changes: 5 additions & 2 deletions packages/jupyter-ai-dalle/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ test = [
"pytest-cov"
]

[project.entry-points."jupyter_ai.model_engine_class"]
test = "jupyter_ai_dalle:DalleModelEngine"
[project.entry-points."jupyter_ai.model_engine_classes"]
DalleModelEngine = "jupyter_ai_dalle:DalleModelEngine"

[project.entry-points."jupyter_ai.default_tasks"]
dalle_default_tasks = "jupyter_ai_dalle:tasks"

[tool.hatch.version]
source = "nodejs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ dependencies = [
"jupyter_ai"
]

[project.entry-points."jupyter_ai.model_engine_class"]
test = "{{ cookiecutter.python_name }}:TestModelEngine"
[project.entry-points."jupyter_ai.model_engine_classes"]
TestModelEngine = "{{ cookiecutter.python_name }}:TestModelEngine"

[project.entry-points."jupyter_ai.default_tasks"]
TestDefaultTasks = "{{ cookiecutter.python_name }}:tasks"

[tool.hatch.version]
source = "nodejs"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from ._version import __version__

# expose TestModelEngine on the root module so that it may be declared as an
# entrypoint in `pyproject.toml`
# expose engines and tasks on the module root so that they may be declared as
# entrypoints in `pyproject.toml`
from .engine import TestModelEngine
from .tasks import tasks


def _jupyter_labextension_paths():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Dict
from typing import Dict

from jupyter_ai.engine import BaseModelEngine, DefaultTaskDefinition
from jupyter_ai.engine import BaseModelEngine
from jupyter_ai.models import DescribeTaskResponse

class TestModelEngine(BaseModelEngine):
Expand All @@ -18,17 +18,6 @@ class TestModelEngine(BaseModelEngine):
# )
#

def list_default_tasks(self) -> List[DefaultTaskDefinition]:
# Tasks your model engine provides by default.
return [
{
"id": "test",
"name": "Test task",
"prompt_template": "{body}",
"insertion_mode": "test"
}
]

async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
# Core method that executes a model when provided with a task
# description and a dictionary of prompt variables. For example, to
Expand Down
8 changes: 7 additions & 1 deletion packages/jupyter-ai/jupyter_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from ._version import __version__
from .extension import AiExtension

# imports to expose entry points. DO NOT REMOVE.
from .engine import GPT3ModelEngine
from .tasks import tasks

# imports to expose types to other AI modules. DO NOT REMOVE.
from .tasks import DefaultTaskDefinition

def _jupyter_labextension_paths():
return [{
"src": "labextension",
"dest": "jupyter_ai"
"dest": "@jupyter-ai/core"
}]

def _jupyter_server_extension_points():
Expand Down
26 changes: 9 additions & 17 deletions packages/jupyter-ai/jupyter_ai/engine.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,30 @@
from abc import abstractmethod, ABC, ABCMeta
from typing import Dict, TypedDict, Literal, List
from typing import Dict
import openai
from traitlets.config import LoggingConfigurable, Unicode
from .task_manager import DescribeTaskResponse

class DefaultTaskDefinition(TypedDict):
id: str
name: str
prompt_template: str
insertion_mode: str

class BaseModelEngineMetaclass(ABCMeta, type(LoggingConfigurable)):
pass

class BaseModelEngine(ABC, LoggingConfigurable, metaclass=BaseModelEngineMetaclass):
id: str
name: str

# these two attributes are currently reserved but unused.
input_type: str
output_type: str

@abstractmethod
def list_default_tasks(self) -> List[DefaultTaskDefinition]:
pass

@abstractmethod
async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
pass

class GPT3ModelEngine(BaseModelEngine):
name = "gpt3"
input_type = "txt"
output_type = "txt"
id = "gpt3"
name = "GPT-3"
modalities = [
"txt2txt"
]

api_key = Unicode(
config=True,
Expand All @@ -40,9 +35,6 @@ class GPT3ModelEngine(BaseModelEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def list_default_tasks(self) -> List[DefaultTaskDefinition]:
return []

async def execute(self, task: DescribeTaskResponse, prompt_variables: Dict[str, str]):
if "body" not in prompt_variables:
raise Exception("Prompt body must be specified.")
Expand Down
34 changes: 29 additions & 5 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ def ai_engines(self):
return self.settings["ai_engines"]

def initialize_settings(self):
# EP := entry point
eps = entry_points()
model_engine_class_eps = eps.select(group="jupyter_ai.model_engine_class")

## step 1: instantiate model engines and bind them to settings
model_engine_class_eps = eps.select(group="jupyter_ai.model_engine_classes")

if not model_engine_class_eps:
self.log.error("No model engines found for jupyter_ai.model_engine_class group. One or more model engines are required for AI extension to work.")
self.log.error("No model engines found for jupyter_ai.model_engine_classes group. One or more model engines are required for AI extension to work.")
return

for model_engine_class_ep in model_engine_class_eps:
Expand All @@ -39,11 +42,32 @@ def initialize_settings(self):
continue

try:
self.ai_engines[Engine.name] = Engine(config=self.config, log=self.log)
self.ai_engines[Engine.id] = Engine(config=self.config, log=self.log)
except:
self.log.error(f"Unable to instantiate model engine class from entry point `{model_engine_class_ep.name}`.")
continue

self.log.info(f"Registered engine `{Engine.name}`.")
self.log.info(f"Registered engine `{Engine.id}`.")

## step 2: load default tasks and bind them to settings
module_default_tasks_eps = eps.select(group="jupyter_ai.default_tasks")

if not module_default_tasks_eps:
self.settings["ai_default_tasks"] = []
return

default_tasks = []
for module_default_tasks_ep in module_default_tasks_eps:
try:
module_default_tasks = module_default_tasks_ep.load()
except:
self.log.error(f"Unable to load task from entry point `{module_default_tasks_ep.name}`")
continue

default_tasks += module_default_tasks

self.settings["ai_default_tasks"] = default_tasks
self.log.info("Registered all default tasks.")

self.log.info(f"Registered {self.name} server extension")
self.log.info(f"Registered {self.name} server extension")

Loading

0 comments on commit c153c9a

Please sign in to comment.