diff --git a/mmdeploy/apis/calibration.py b/mmdeploy/apis/calibration.py index 3aa21f374a..42a72d90a6 100644 --- a/mmdeploy/apis/calibration.py +++ b/mmdeploy/apis/calibration.py @@ -62,7 +62,7 @@ def create_calib_input_data(calib_file: str, dataset = task_processor.build_dataset(dataset_cfg, dataset_type) # patch model - backend = get_backend(deploy_cfg) + backend = get_backend(deploy_cfg).value ir = IR.get(get_ir_config(deploy_cfg)['type']) patched_model = patch_model( model, cfg=deploy_cfg, backend=backend, ir=ir) diff --git a/mmdeploy/core/rewriters/rewriter_manager.py b/mmdeploy/core/rewriters/rewriter_manager.py index de3acaffd2..5e84d723d4 100644 --- a/mmdeploy/core/rewriters/rewriter_manager.py +++ b/mmdeploy/core/rewriters/rewriter_manager.py @@ -48,7 +48,11 @@ def patch_model(model: nn.Module, Examples: >>> from mmdeploy.core import patch_model - >>> patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) + >>> from mmdeploy.utils import Backend, IR + >>> deploy_cfg = {} + >>> backend = Backend.DEFAULT.value + >>> ir = IR.ONNX + >>> patched_model = patch_model(model, deploy_cfg, backend, ir) """ return MODULE_REWRITER.patch_model(model, cfg, backend, ir, recursive, **kwargs)