Skip to content

Commit

Permalink
Add timeout argument to civitai.image.create
Browse files Browse the repository at this point in the history
  • Loading branch information
bryantanjw committed Mar 29, 2024
1 parent 154a040 commit 97c2e8f
Show file tree
Hide file tree
Showing 15 changed files with 20 additions and 180 deletions.
31 changes: 0 additions & 31 deletions .gitlab-ci.yml

This file was deleted.

17 changes: 0 additions & 17 deletions .travis.yml

This file was deleted.

5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ input = {
Run a model:

```python
response = civitai.image.create(input)
response = civitai.image.create(input, timeout=60) # Timeout is optional and is None by default
```

_Note: Jobs timeout after 5 minutes._

### Using Additional Networks

The SDK supports additional networks: LoRA, VAE, Hypernetwork, Textual Inversion, LyCORIS, Checkpoint, and LoCon.
Expand Down Expand Up @@ -111,6 +109,7 @@ response = civitai.image.create(options)

| name | type | description |
| ----------------------- | ------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `timeout` | number \| null | Optional. The maximum time in seconds to poll for the image generation. `None` by default. |
| `model` | string \| null | **Required**. The Civitai model to use for generation. |
| `params.prompt` | string \| null | **Required**. The main prompt for the image generation. |
| `params.negativePrompt` | string \| null | Optional. The negative prompt for the image generation. |
Expand Down
23 changes: 2 additions & 21 deletions civitai_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class Image:
def __init__(self, civitai_py):
self.civitai_py = civitai_py

def create(self, input, wait=False):
def create(self, input, timeout=None):
"""
Submits a new image generation job and optionally waits for its completion.
Expand Down Expand Up @@ -285,6 +285,7 @@ def create(self, input, wait=False):
response = self.civitai_py.jobs_api.v1_consumer_jobs_post(
wait=True,
detailed=False,
_request_timeout=timeout,
job_template_list=job_input
)
except ApiException as e:
Expand All @@ -305,26 +306,6 @@ def create(self, input, wait=False):
}
return modified_response

# Helper methods
# def _poll_for_job_completion(self, token, interval=30, timeout=300):
# """
# Polls the job status until completion or timeout.

# :param token: The token of the job to poll.
# :param interval: The interval (in seconds) between status checks.
# :param timeout: The maximum time (in seconds) to wait for job completion.
# :return: The result of the job if completed, None otherwise.
# """
# start_time = time.time()
# while time.time() - start_time < timeout:
# response = self.civitai_py.jobs.get(token=token)
# if response and response.jobs:
# job = response.jobs[0]
# if job.result and job.result.get("blobUrl"):
# return job.result
# time.sleep(interval)
# raise TimeoutError(f"Job {token} did not complete within {timeout} seconds.")


# Create an instance of Civitai and assign it to the variable 'civitai_py'
civitai_py = Civitai()
Expand Down
4 changes: 2 additions & 2 deletions civitai_py/models/provider_job_queue_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ProviderJobQueuePosition(BaseModel):
default=None, description="The estimated throughput rate of the queue", alias="throughputRate")
worker_id: Optional[StrictStr] = Field(
default=None, description="The id of the worker that this job is queued with", alias="workerId")
estimated_start_duration: Optional[TimeSpan] = Field(
estimated_start_duration: Optional[Any] = Field(
default=None, alias="estimatedStartDuration")
estimated_start_date: Optional[datetime] = Field(
default=None, description="The date before the job is estimated to be started. Null if we do not have an estimate yet", alias="estimatedStartDate")
Expand Down Expand Up @@ -113,7 +113,7 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
"precedingCost": obj.get("precedingCost"),
"throughputRate": obj.get("throughputRate"),
"workerId": obj.get("workerId"),
"estimatedStartDuration": TimeSpan.from_dict(obj["estimatedStartDuration"]) if obj.get("estimatedStartDuration") is not None else None,
"estimatedStartDuration": obj.get("estimatedStartDuration"),
"estimatedStartDate": obj.get("estimatedStartDate")
})
return _obj
21 changes: 4 additions & 17 deletions civitai_py/models/time_span.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# time_span.py
# coding: utf-8

"""
Civitai Orchestration Consumer API
Expand Down Expand Up @@ -119,26 +119,13 @@ def to_dict(self) -> Dict[str, Any]:
return _dict

@classmethod
def from_dict(cls, obj: Optional[Union[Dict[str, Any], str]]) -> Optional[Self]:
"""Create an instance of TimeSpan from a dict or a string"""
def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
"""Create an instance of TimeSpan from a dict"""
if obj is None:
return None

if isinstance(obj, str):
# Parse the string representation of TimeSpan
time_parts = obj.split(':')
if len(time_parts) == 3:
hours, minutes, seconds = map(int, time_parts)
obj = {
'hours': hours,
'minutes': minutes,
'seconds': seconds
}
else:
raise ValueError(f"Invalid TimeSpan string format: {obj}")

if not isinstance(obj, dict):
raise ValueError(f"Invalid input type for TimeSpan: {type(obj)}")
return cls.model_validate(obj)

_obj = cls.model_validate({
"ticks": obj.get("ticks"),
Expand Down
15 changes: 5 additions & 10 deletions examples/streamlit_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
st.set_page_config(layout="wide")


def generate_image(input_data, wait_for_completion):
def generate_image(input_data):
"""
Generate an image using the Civitai SDK.
:param input_data: The input parameters for the image generation.
:param wait_for_completion: Whether to wait for the job to complete before returning the result.
:return: The job result if wait_for_completion is True, otherwise the job token.
:return: The job result
"""
try:
output = civitai.image.create(input_data, wait=wait_for_completion)
output = civitai.image.create(input_data)
return output
except Exception as e:
st.error(f"Failed to generate image: {str(e)}")
Expand All @@ -37,7 +36,6 @@ def generate_image(input_data, wait_for_completion):
width = st.number_input("Width", min_value=1, max_value=1024, value=768)
height = st.number_input("Height", min_value=1, max_value=1024, value=512)
seed = st.number_input("Seed", value=-1)
wait_for_completion = st.checkbox("Wait for completion", value=True)
submit_button = st.form_submit_button("Generate Image")

if submit_button:
Expand All @@ -55,16 +53,13 @@ def generate_image(input_data, wait_for_completion):
},
}

result = generate_image(input_data, wait_for_completion)
result = generate_image(input_data)

if result and wait_for_completion:
if result:
if 'jobs' in result and result['jobs'][0].get('result'):
with col_image:
st.image(result['jobs'][0]['result']['blobUrl'],
caption="Generated Image", use_column_width=True)
else:
with col_image:
st.error("Failed to retrieve the generated image.")
elif result and not wait_for_completion:
with col_image:
st.success(f"Job submitted. Token: {result.get('token')}")
2 changes: 1 addition & 1 deletion examples/streamlit_demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
streamlit
civitai_py==0.1.1
civitai_py
2 changes: 1 addition & 1 deletion examples/text2img.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"}\n",
"\n",
"# Generate the image\n",
"response = civitai_py.image.create(input_params, wait=True)\n",
"response = civitai.image.create(input_params, wait=True)\n",
"print(\"Response:\", response)"
]
},
Expand Down
57 changes: 0 additions & 57 deletions git_push.sh

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "civitai-py"
version = "0.1.1"
version = "0.1.0"
description = "Civitai Python SDK"
authors = ["Civitai <hello@civitai.com>"]
license = "MIT"
Expand Down
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest-cov>=2.8.1
pytest-randomly>=3.12.0
mypy>=1.4.1
types-python-dateutil>=2.8.19
unittest
10 changes: 1 addition & 9 deletions tests/test_create_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,7 @@ def test_create_from_text_job(self):
},
}

# Test case when wait=False
# output = civitai.image.create(input, wait=False)
# print("Response (wait=False):", output)

# self.assertIsNotNone(output, "The output should not be None.")
# self.assertIn("token", output, "The output should contain a 'token' key.")

# Test case when wait=True
output = civitai.image.create(input, wait=True)
output = civitai.image.create(input, timeout=2)
print("Response (wait=True):", output)

self.assertIsNotNone(output, "The output should not be None.")
Expand Down
1 change: 0 additions & 1 deletion tests/test_query_jobs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# test/test_query_jobs.py
import unittest
import civitai_py as civitai
import json


class TestJobQuery(unittest.TestCase):
Expand Down
9 changes: 0 additions & 9 deletions tox.ini

This file was deleted.

0 comments on commit 97c2e8f

Please sign in to comment.