diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index f4071b674..4c4f03d29 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -21,7 +21,7 @@ from .config_struct import ModelParams from .manager import SystemManager -from .messages import InferenceExecRequest, InferencePhase +from .messages import SDXLInferenceExecRequest, InferencePhase from .tokenizer import Tokenizer from .metrics import measure, log_duration_str @@ -221,9 +221,9 @@ def __init__( self.service = service self.meta_fiber = meta_fiber self.worker_index = meta_fiber.worker_idx - self.exec_request: InferenceExecRequest = None + self.exec_request: SDXLInferenceExecRequest = None - def assign_command_buffer(self, request: InferenceExecRequest): + def assign_command_buffer(self, request: SDXLInferenceExecRequest): for cb in self.meta_fiber.command_buffers: if cb.sample.shape[0] == self.exec_request.batch_size: self.exec_request.set_command_buffer(cb)