Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[INFER][LLM] Add the AutoPredictor for inference #9445

Merged
merged 9 commits into from
Dec 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
polish
zeroRains committed Nov 27, 2024
commit 1e89250ebe8d08cf4efcab6ea5ba8f78a9ad7604
15 changes: 5 additions & 10 deletions llm/predict/predictor.py
Original file line number Diff line number Diff line change
@@ -58,12 +58,10 @@
)


def get_attention_type(*args):
"""
It must be passed in the following order.
(block_attn)
"""
def get_attention_type(predictor_args):
count = 0
# It must follow this order
args = predictor_args.block_attn
res = []
for attn_type in args:
if attn_type:
@@ -73,10 +71,7 @@ def get_attention_type(*args):
res.append(False)
if count > 1:
raise ValueError("Only one attention type can be used")
try:
return ATTENTION_TYPE_FOR_PREDICTOR_MAPPING_NAMES[tuple(res)]
except KeyError:
raise ValueError("Unknown attention type")
return ATTENTION_TYPE_FOR_PREDICTOR_MAPPING_NAMES[tuple(res)]


@dataclass
@@ -1315,7 +1310,7 @@ def create_predictor(
# infer/ no infer
if predictor_args.inference_model:
# block/no block
attn_type = get_attention_type(predictor_args.block_attn)
attn_type = get_attention_type(predictor_args)
inference_mode = f"{attn_type}Inference"

if predictor_args.mode == "static":
15 changes: 5 additions & 10 deletions paddlenlp/transformers/auto/modeling.py
Original file line number Diff line number Diff line change
@@ -170,12 +170,10 @@
)


def get_attention_type(*args):
"""
It must be passed in the following order.
(block_attn, speculate_attn)
"""
def get_attention_type(predictor_args):
count = 0
# It must follow this order
args = (predictor_args.block_attn, predictor_args.speculate_attn)
res = []
for attn_type in args:
if attn_type:
@@ -185,10 +183,7 @@ def get_attention_type(*args):
res.append(False)
if count > 1:
raise ValueError("Only one attention type can be True")
try:
return ATTENTION_TYPE_FOR_MODEL_MAPPING_NAMES[tuple(res)]
except KeyError:
raise ValueError("Unknown attention type")
return ATTENTION_TYPE_FOR_MODEL_MAPPING_NAMES[tuple(res)]


def get_name_mapping(task="Model"):
@@ -860,7 +855,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
)
else:
# Check whether the model use block attention
attn_type = get_attention_type(predictor_args.block_attn, predictor_args.speculate_attn)
attn_type = get_attention_type(predictor_args)
model_name = f"{config.architectures[0]}{attn_type}"

# Import the InferenceModel