Skip to content

Commit

Permalink
add 图片擦除&局部重绘 (#42)
Browse files Browse the repository at this point in the history
图片擦除&局部重绘
  • Loading branch information
niehongxu116 authored Dec 24, 2024
1 parent 950bb40 commit 37275b9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
21 changes: 21 additions & 0 deletions client/py/yidong/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
EditorConfig,
ImageGenerationTask,
ImageGenerationTaskResult,
ImageInpaintTask,
ImageInpaintTaskResult,
ImageRemoveTask,
ImageRemoveTaskResult,
Pagination,
PingTask,
PingTaskResult,
Expand Down Expand Up @@ -417,6 +421,23 @@ def image_generation(
"""Generate images based on the given prompt or the reference image."""
return self._submit_task(locals())

def image_inpaint(
self,
image_id: str,
mask_base64: str,
prompt: str | None = None,
) -> TaskRef[ImageInpaintTask, ImageInpaintTaskResult]:
"""image inpaint based on the mask image base64 string and the given prompt."""
return self._submit_task(locals())

def image_remove(
self,
image_id: str,
mask_base64: str,
) -> TaskRef[ImageRemoveTask, ImageRemoveTaskResult]:
"""image remove based on the mask image base64 string."""
return self._submit_task(locals())


def main():
rich.print(CLI(YiDong))
Expand Down
27 changes: 27 additions & 0 deletions client/py/yidong/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,29 @@ class ImageGenerationTaskResult(BaseModel):
generated_image_ids: list[str]


class ImageInpaintTask(BaseModel):
type: Literal["image_inpaint"] = "image_inpaint"
image_id: str = Field(..., min_length=1)
mask_base64: str = Field(..., min_length=1)
prompt: str | None = None


class ImageInpaintTaskResult(BaseModel):
type: Literal["image_inpaint"] = "image_inpaint"
generated_image_ids: list[str]


class ImageRemoveTask(BaseModel):
type: Literal["image_remove"] = "image_remove"
image_id: str = Field(..., min_length=1)
mask_base64: str = Field(..., min_length=1)


class ImageRemoveTaskResult(BaseModel):
type: Literal["image_remove"] = "image_remove"
generated_image_ids: list[str]


Task = Annotated[
Union[
PingTask,
Expand All @@ -336,6 +359,8 @@ class ImageGenerationTaskResult(BaseModel):
VideoConcatTask,
VideoSnapshotTask,
ImageGenerationTask,
ImageInpaintTask,
ImageRemoveTask,
],
Field(discriminator="type"),
]
Expand All @@ -351,6 +376,8 @@ class ImageGenerationTaskResult(BaseModel):
VideoConcatTaskResult,
VideoSnapshotTaskResult,
ImageGenerationTaskResult,
ImageInpaintTaskResult,
ImageRemoveTaskResult,
],
Field(discriminator="type"),
]
Expand Down

0 comments on commit 37275b9

Please sign in to comment.