Skip to content

Commit

Permalink
Add seqlen
Browse files Browse the repository at this point in the history
  • Loading branch information
elvircrn committed Dec 3, 2024
1 parent 76ec14a commit c29de80
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
8 changes: 7 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,12 @@ def perplexity_eval(model, testenc, args, dev):
choices=["auto", "float16", "float32"],
help="dtype to load the model.",
)
parser.add_argument(
"--model_seqlen",
type=int,
default=4096,
help="Model seqlen and calibration data context length.",
)

args = parser.parse_args()

Expand Down Expand Up @@ -574,7 +580,7 @@ def perplexity_eval(model, testenc, args, dev):
# device = "cpu"

print("============ Loading model... ============")
model = get_model(args.model_path, args.load, args.dtype).train(False)
model = get_model(args.model_path, args.load, args.dtype, args.model_seqlen).train(False)

model = model.to(device=device)
print("\n============ Quantizing model... ============")
Expand Down
5 changes: 3 additions & 2 deletions modelutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def skip(*args, **kwargs):
torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = saved_inits # restoring


def get_model(model_path, load_quantized=None, dtype="auto"):
def get_model(model_path, load_quantized=None, dtype="auto", seqlen=4096):
if dtype == "auto":
dtype = (
AutoConfig.from_pretrained(model_path, trust_remote_code=True).torch_dtype or "auto"
Expand All @@ -50,7 +50,8 @@ def get_model(model_path, load_quantized=None, dtype="auto"):
config=config
# local_files_only=True
)
model.seqlen = 2048

model.seqlen = seqlen

print("Model loaded sucessfully ...")

Expand Down

0 comments on commit c29de80

Please sign in to comment.