Skip to content

Commit

Permalink
Add inference builder for cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
oelayan7 committed Jun 4, 2024
1 parent aed7204 commit c6f48c9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
6 changes: 4 additions & 2 deletions accelerator/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,18 @@ def get_op_builder(self, class_name):
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, InferenceBuilder, NotImplementedBuilder
except ImportError:
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, InferenceBuilder, NotImplementedBuilder

if class_name == "CCLCommBuilder":
return CCLCommBuilder
elif class_name == "ShareMemCommBuilder":
return ShareMemCommBuilder
elif class_name == "FusedAdamBuilder":
return FusedAdamBuilder
elif class_name == "InferenceBuilder":
return InferenceBuilder
elif class_name == "CPUAdamBuilder":
return CPUAdamBuilder
else:
Expand Down
1 change: 1 addition & 0 deletions op_builder/cpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .comm import CCLCommBuilder, ShareMemCommBuilder
from .fused_adam import FusedAdamBuilder
from .cpu_adam import CPUAdamBuilder
from .transformer_inference import InferenceBuilder
from .no_impl import NotImplementedBuilder
39 changes: 39 additions & 0 deletions op_builder/cpu/transformer_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import importlib

try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
from op_builder.builder import OpBuilder
except ImportError:
from deepspeed.ops.op_builder.builder import OpBuilder


class InferenceBuilder(OpBuilder):
BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
NAME = "transformer_inference"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=self.NAME)

def absolute_name(self):
return f"deepspeed.ops.transformer.inference.{self.NAME}_op"

def sources(self):
return []

def load(self, verbose=True):
if self.name in __class__._loaded_ops:
return __class__._loaded_ops[self.name]

from deepspeed.git_version_info import installed_ops # noqa: F401
if installed_ops.get(self.name, False):
op_module = importlib.import_module(self.absolute_name())
__class__._loaded_ops[self.name] = op_module
return op_module

0 comments on commit c6f48c9

Please sign in to comment.