diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 945850c50..bdf7bf677 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -19,16 +19,20 @@ from QEfficient.customop.ctx_scatter_gather import ( CtxGather, CtxGather3D, + CtxGatherBlockedKV, CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFuncBlockedKV, CtxScatter, CtxScatter3D, CtxScatterFunc, CtxScatterFunc3D, ) from QEfficient.customop.ctx_scatter_gather_cb import ( + CtxGatherBlockedKVCB, CtxGatherCB, CtxGatherCB3D, + CtxGatherFuncBlockedKVCB, CtxGatherFuncCB, CtxGatherFuncCB3D, CtxScatterCB, @@ -95,6 +99,8 @@ class CustomOpTransform(BaseOnnxTransform): "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), + "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV), + "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB), } @classmethod