gptq_marlin: Fix bug report vllm-project#5088 (comment)
alexm-redhat committed May 29, 2024
1 parent 594392d commit a44225e
Showing 2 changed files with 5 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/offline_inference.py
@@ -11,7 +11,7 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="TechxGenus/gemma-1.1-2b-it-GPTQ")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
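For context, a minimal, self-contained version of the example script after this change might look like the sketch below. It is not the full examples/offline_inference.py: the prompt list here is illustrative, and it assumes vLLM is installed on a GPU that can hold the GPTQ checkpoint.

import torch  # noqa: F401  (vLLM requires a working torch/CUDA install)
from vllm import LLM, SamplingParams

# Illustrative prompts; the real example script defines its own list.
prompts = [
    "Hello, my name is",
    "The capital of France is",
]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM. The example now defaults to a GPTQ-quantized checkpoint,
# which exercises the gptq_marlin path touched by this commit.
llm = LLM(model="TechxGenus/gemma-1.1-2b-it-GPTQ")

# Generate texts from the prompts and print them.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")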
12 changes: 4 additions & 8 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -298,14 +298,10 @@ def create_weights(
             },
         )
 
-        g_idx_sort_indices = Parameter(
-            torch.empty(
-                g_idx.shape,
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
+        g_idx_sort_indices = torch.empty(
+            g_idx.shape,
+            dtype=torch.int32,
         )
-        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
 
         # Scales
         scales = Parameter(
@@ -356,9 +352,9 @@ def create_weights(
 
         layer.register_parameter("qweight", qweight)
         layer.register_parameter("g_idx", g_idx)
-        layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
         layer.register_parameter("scales", scales)
         layer.register_parameter("qzeros", qzeros)
+        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
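The gptq_marlin change turns g_idx_sort_indices from a Parameter into a plain torch.Tensor, so it can no longer be registered via nn.Module.register_parameter (which accepts only nn.Parameter or None) and is attached as an ordinary attribute instead. A standalone sketch of that PyTorch behavior, using a throwaway nn.Linear in place of the real vLLM layer:

import torch
from torch import nn

# Stand-in for the vLLM linear layer; any nn.Module shows the same behavior.
layer = nn.Linear(8, 8)

# Plain int32 tensor, analogous to g_idx_sort_indices after this commit.
g_idx_sort_indices = torch.empty(8, dtype=torch.int32)

try:
    # register_parameter requires an nn.Parameter (or None)...
    layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
except TypeError as err:
    print(f"register_parameter rejected the plain tensor: {err}")

# ...so the tensor is stored as a regular attribute instead, mirroring the new
# `layer.g_idx_sort_indices = g_idx_sort_indices` line in this commit.
layer.g_idx_sort_indices = g_idx_sort_indices
print(type(layer.g_idx_sort_indices))  # <class 'torch.Tensor'>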
