Skip to content

Commit

Permalink
need additional_special_tokens argument for HFLM initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
kumapo committed Oct 22, 2023
1 parent 5e80a0b commit 679f918
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def simple_evaluate(
assert isinstance(model, lm_eval.base.LM)
lm = model

print(f"evaluator: no_cache={no_cache}")
if not no_cache:
lm = lm_eval.base.CachingLM(
lm,
Expand Down
2 changes: 2 additions & 0 deletions lm_eval/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
use_fast: Optional[bool] = True,
additional_special_tokens: Optional[str] = None
):
super().__init__()

Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast,
additional_special_tokens=additional_special_tokens
)
self.vocab_size = self.tokenizer.vocab_size

Expand Down
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def main():
if args.description_dict_path:
with open(args.description_dict_path, "r") as f:
description_dict = json.load(f)
print(f"main: no_cache={args.no_cache}")
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
set -eu
MODEL_ARGS="pretrained=stabilityai/japanese-stablelm-base-alpha-7b,use_fast=False,trust_remote_code=True,device_map=auto,torch_dtype=auto,offload_folder=/tmp"
MODEL_ARGS="pretrained=stabilityai/japanese-stablelm-base-alpha-7b,use_fast=False,trust_remote_code=True,device_map=auto,torch_dtype=auto,load_in_8bit=True,offload_folder=/tmp,tokenizer=novelai/nerdstash-tokenizer-v1,additional_special_tokens=['▁▁']"
TASK="jcommonsenseqa-1.2-0.2"
NUM_FEW_SHOTS="3"
python main.py \
Expand All @@ -9,4 +9,5 @@ python main.py \
--tasks $TASK \
--num_fewshot $NUM_FEW_SHOTS \
--device "cuda" \
--no_cache \
--output_path "models/stablelm/stablelm-ja-base-alpha-7b/result.jcqa-1.2.json"

0 comments on commit 679f918

Please sign in to comment.