update ollama plugin to reflect API changes #3065

Draft · wants to merge 10 commits into base: master
13 changes: 4 additions & 9 deletions plugins/flytekit-inference/README.md
@@ -74,8 +74,6 @@ The Ollama plugin allows you to serve LLMs locally.
You can either pull an existing model or create a new one.

```python
from textwrap import dedent

from flytekit import ImageSpec, Resources, task, workflow
from flytekitplugins.inference import Ollama, Model
from flytekit.extras.accelerators import A10G
@@ -91,13 +89,10 @@ image = ImageSpec(
ollama_instance = Ollama(
model=Model(
name="llama3-mario",
modelfile=dedent("""\
FROM llama3
ADAPTER {inputs.gguf}
PARAMETER temperature 1
PARAMETER num_ctx 4096
SYSTEM You are Mario from super mario bros, acting as an assistant.\
"""),
from_="llama3",
adapters=["gguf"],
parameters={"temperature": 1, "num_ctx": 4096},
system="You are Mario from super mario bros, acting as an assistant."
)
)

```
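A minimal usage sketch continuing the example above; it assumes the template's existing `pod_template` and `base_url` attributes and an OpenAI-compatible client for querying the sidecar, so the task shown is illustrative rather than the README's verbatim continuation:

```python
from flytekit import task
from flytekit.types.file import FlyteFile
from openai import OpenAI  # assumed client; any HTTP client that can reach base_url works


@task(
    container_image=image,
    pod_template=ollama_instance.pod_template,  # mounts the Ollama sidecar
    accelerator=A10G,
)
def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]:
    # The sidecar serves the custom "llama3-mario" model, built from the
    # LoRA adapter passed in as the `gguf` task input (see adapters=["gguf"]).
    client = OpenAI(base_url=f"{ollama_instance.base_url}/v1", api_key="ollama")
    responses = []
    for question in questions:
        completion = client.chat.completions.create(
            model="llama3-mario",
            messages=[{"role": "user", "content": question}],
        )
        responses.append(completion.choices[0].message.content)
    return responses
```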
126 changes: 91 additions & 35 deletions plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py
@@ -1,24 +1,43 @@
import base64
from dataclasses import dataclass
from typing import Optional
from typing import Any, Mapping, Optional, Sequence, Union

from ..sidecar_template import ModelInferenceTemplate


NEWLINE = "\n"
NEWLINE_ESCAPED = "\\n"


@dataclass
class Model:
"""Represents the configuration for a model used in a Kubernetes pod template.

:param name: The name of the model.
:param mem: The amount of memory allocated for the model, specified as a string. Default is "500Mi".
:param cpu: The number of CPU cores allocated for the model. Default is 1.
:param modelfile: The actual model file as a JSON-serializable string. This represents the file content. Default is `None` if not applicable.
:param from_: The name of an existing model to use as the base for the new custom model.
:param files: A list of file names to create the model from.
:param adapters: A list of file names for LoRA adapters used to create the model.
:param template: The prompt template for the model.
:param license: A string or list of strings containing the license or licenses for the model.
:param system: A string containing the system prompt for the model.
:param parameters: A dictionary of parameters for the model.
:param messages: A list of message objects used to create a conversation.
:param quantize: Quantize a non-quantized (e.g. float16) model.
"""

name: str
mem: str = "500Mi"
cpu: int = 1
modelfile: Optional[str] = None
from_: Optional[str] = None
files: Optional[list[str]] = None
adapters: Optional[list[str]] = None
template: Optional[str] = None
license: Optional[Union[str, list[str]]] = None
system: Optional[str] = None
parameters: Optional[Mapping[str, Any]] = None
messages: Optional[Sequence[Mapping[str, Any]]] = None
quantize: Optional[str] = None
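For orientation, the free-form `modelfile` string is replaced by structured fields that map directly onto the keyword arguments of the ollama client's `create()` call further down. An illustrative sketch, mirroring the README example:

```python
from flytekitplugins.inference import Model

# The README's Modelfile-based configuration, expressed with the new fields.
mario = Model(
    name="llama3-mario",
    from_="llama3",                                   # FROM llama3
    adapters=["gguf"],                                # ADAPTER {inputs.gguf}
    parameters={"temperature": 1, "num_ctx": 4096},   # PARAMETER temperature / num_ctx
    system="You are Mario from super mario bros, acting as an assistant.",  # SYSTEM ...
)

# Pulling an unmodified upstream model still only needs a name (and optional resources).
plain = Model(name="llama3", cpu=2, mem="1Gi")
```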


class Ollama(ModelInferenceTemplate):
@@ -36,7 +55,10 @@ def __init__(
):
"""Initialize Ollama class for managing a Kubernetes pod template.

:param model: An instance of the Model class containing the model's configuration, including its name, memory, CPU, and file.
Python 3.12 or higher is required because backslashes inside f-string expressions are only supported from 3.12 onward:
https://realpython.com/python312-f-strings/#backslashes-now-allowed-in-f-strings

Comment on lines +58 to +59 (Contributor):
I think using the NEWLINE and NEWLINE_ESCAPED global vars should address this.
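A minimal standalone sketch of the restriction and of the workaround the module-level constants provide:

```python
# On Python <= 3.11 a backslash may not appear inside the expression part of
# an f-string, so f"{text.replace('\n', '\\n')}" raises a SyntaxError there.
# Hoisting the escape sequences into module-level constants keeps the
# expression backslash-free; the same code also runs unchanged on 3.12+.
NEWLINE = "\n"
NEWLINE_ESCAPED = "\\n"

text = "first line\nsecond line"
print(f"escaped: {text.replace(NEWLINE, NEWLINE_ESCAPED)}")
# prints: escaped: first line\nsecond line
```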

:param model: An instance of the Model class containing the model's configuration, including its name, memory, CPU, and the model-creation parameters.
:param image: The Docker image to be used for the container. Default is "ollama/ollama".
:param port: The port number on which the container should expose its service. Default is 11434.
:param cpu: The number of CPU cores requested for the container. Default is 1.
@@ -48,7 +70,15 @@ def __init__(
self._model_name = model.name
self._model_mem = model.mem
self._model_cpu = model.cpu
self._model_modelfile = model.modelfile
self._model_from = model.from_
self._model_files = model.files
self._model_adapters = model.adapters
self._model_template = model.template
self._model_license = model.license
self._model_system = model.system
self._model_parameters = model.parameters
self._model_messages = model.messages
self._model_quantize = model.quantize

super().__init__(
image=image,
@@ -58,7 +88,7 @@ def __init__(
mem=mem,
download_inputs_mem=download_inputs_mem,
download_inputs_cpu=download_inputs_cpu,
download_inputs=(True if self._model_modelfile and "{inputs" in self._model_modelfile else False),
download_inputs=bool(self._model_adapters or self._model_files),
)

self.setup_ollama_pod_template()
@@ -71,7 +101,19 @@ def setup_ollama_pod_template(self):
V1VolumeMount,
)

container_name = "create-model" if self._model_modelfile else "pull-model"
custom_model = any(
[
self._model_files,
self._model_adapters,
self._model_template,
self._model_license,
self._model_system,
self._model_parameters,
self._model_messages,
self._model_quantize,
]
)
container_name = "create-model" if custom_model else "pull-model"

base_code = """
import base64
@@ -97,53 +139,64 @@ def setup_ollama_pod_template(self):
print('Ollama service did not become ready in time')
exit(1)
"""
if self._model_modelfile:
encoded_modelfile = base64.b64encode(self._model_modelfile.encode("utf-8")).decode("utf-8")

if "{inputs" in self._model_modelfile:
if custom_model:
if self._model_files or self._model_adapters:
python_code = f"""
{base_code}
import json
from ollama._client import Client

with open('/shared/inputs.json', 'r') as f:
inputs = json.load(f)

class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self

inputs = {{'inputs': AttrDict(inputs)}}
files = {{}}
adapters = {{}}
client = Client('{self.base_url}')

encoded_model_file = '{encoded_modelfile}'

modelfile = base64.b64decode(encoded_model_file).decode('utf-8').format(**inputs)
modelfile = modelfile.replace('{{', '{{{{').replace('}}', '}}}}')

with open('Modelfile', 'w') as f:
f.write(modelfile)
for input_name, input_value in inputs.items():
if {self._model_files} and input_name in {self._model_files}:
files[input_name] = client.create_blob(input_value)
elif {self._model_adapters} and input_name in {self._model_adapters}:
adapters[input_name] = client.create_blob(input_value)

{ollama_service_ready}

# Debugging: Shows the status of model creation.
for chunk in ollama.create(model='{self._model_name}', path='Modelfile', stream=True):
for chunk in ollama.create(
model={"'" + self._model_name + "'" if self._model_name else None},
from_={"'" + self._model_from + "'" if self._model_from else None},
files=files if files else None,
adapters=adapters if adapters else None,
template={"'" + self._model_template.replace(NEWLINE, NEWLINE_ESCAPED) + "'" if self._model_template else None},
license={"'" + self._model_license + "'" if self._model_license else None},
system={"'" + self._model_system.replace(NEWLINE, NEWLINE_ESCAPED) + "'" if self._model_system else None},
parameters={self._model_parameters if self._model_parameters else None},
messages={self._model_messages if self._model_messages else None},
quantize={"'" + self._model_quantize + "'" if self._model_quantize else None},
stream=True
):
print(chunk)
"""
else:
python_code = f"""
{base_code}

encoded_model_file = '{encoded_modelfile}'

modelfile = base64.b64decode(encoded_model_file).decode('utf-8')

with open('Modelfile', 'w') as f:
f.write(modelfile)

{ollama_service_ready}

# Debugging: Shows the status of model creation.
for chunk in ollama.create(model='{self._model_name}', path='Modelfile', stream=True):
for chunk in ollama.create(
model={"'" + self._model_name + "'" if self._model_name else None},
from_={"'" + self._model_from + "'" if self._model_from else None},
files=None,
adapters=None,
template={"'" + self._model_template.replace(NEWLINE, NEWLINE_ESCAPED) + "'" if self._model_template else None},
license={"'" + self._model_license + "'" if self._model_license else None},
system={"'" + self._model_system.replace(NEWLINE, NEWLINE_ESCAPED) + "'" if self._model_system else None},
parameters={self._model_parameters if self._model_parameters else None},
messages={self._model_messages if self._model_messages else None},
quantize={"'" + self._model_quantize + "'" if self._model_quantize else None},
stream=True
):
print(chunk)
"""
else:
@@ -164,7 +217,10 @@ def __init__(self, *args, **kwargs):
name=container_name,
image="python:3.11-slim",
command=["/bin/sh", "-c"],
args=[f"pip install requests && pip install ollama==0.3.3 && {command}"],
args=[
"apt-get update && apt-get install -y git && "
f"pip install requests && pip install git+https://github.com/ollama/ollama-python.git@eefe5c9666e2fa82ab17618155dd0aae47bba8fa && {command}"
],
resources=V1ResourceRequirements(
requests={
"cpu": self._model_cpu,