
[BUG] Eval recipe not using max_seq_length #1644

Closed
SalmanMohammadi opened this issue Sep 21, 2024 · 5 comments · Fixed by #1773

Comments

@SalmanMohammadi
Collaborator

2024-09-21:20:19:56,843 INFO     [_logging.py:101] Running EleutherEvalRecipe with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ./target/1b_normal
  checkpoint_files:
  - pytorch_model.bin
  model_type: LLAMA2
  output_dir: ./target/tmp
device: cuda
dtype: fp32
enable_kv_cache: true
limit: null
max_seq_length: 1024
model:
  _component_: torchtune.models.llama2.llama2
  embed_dim: 2048
  max_seq_len: 2048
  norm_eps: 1.0e-05
  num_heads: 32
  num_kv_heads: 4
  num_layers: 22
  vocab_size: 32000
quantizer: null
seed: 1234
tasks:
- mmlu_pro
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: ./target/1b_normal/tokenizer.model

2024-09-21:20:19:57,698 DEBUG    [seed.py:60] Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
2024-09-21:20:19:58,522 INFO     [eleuther_eval.py:237] Model is initialized with precision torch.float32.
2024-09-21:20:19:58,532 INFO     [eleuther_eval.py:209] Tokenizer is initialized from file.
2024-09-21:20:19:58,659 INFO     [huggingface.py:130] Using device 'cuda:0'
/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
2024-09-21:20:19:58,987 INFO     [huggingface.py:366] Model parallel was set to False, max memory was not set, and device map was set to {'': 'cuda:0'}
2024-09-21:20:20:00,200 INFO     [__init__.py:491] `group` and `group_alias` keys in TaskConfigs are deprecated and will be removed in v0.4.5 of lm_eval. The new `tag` field will be used to allow for a shortcut to a group of tasks one does not wish to aggregate metrics across. `group`s which aggregate across subtasks must be only defined in a separate group config file, which will be the official way to create groups that support cross-task aggregation as in `mmlu`. Please see the v0.4.4 patch notes and our documentation: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#advanced-group-configs for more information.
2024-09-21:20:20:33,959 INFO     [eleuther_eval.py:280] Running evaluation on ['mmlu_pro'] tasks.
2024-09-21:20:20:33,961 INFO     [task.py:423] Building contexts for mmlu_pro_biology on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 717/717 [00:00<00:00, 861.90it/s]
2024-09-21:20:20:34,879 INFO     [task.py:423] Building contexts for mmlu_pro_business on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 789/789 [00:00<00:00, 860.26it/s]
2024-09-21:20:20:35,883 INFO     [task.py:423] Building contexts for mmlu_pro_chemistry on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1132/1132 [00:01<00:00, 816.52it/s]
2024-09-21:20:20:37,392 INFO     [task.py:423] Building contexts for mmlu_pro_computer_science on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 410/410 [00:00<00:00, 843.05it/s]
2024-09-21:20:20:37,929 INFO     [task.py:423] Building contexts for mmlu_pro_economics on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 842.82it/s]
2024-09-21:20:20:39,024 INFO     [task.py:423] Building contexts for mmlu_pro_engineering on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 969/969 [00:01<00:00, 871.45it/s]
2024-09-21:20:20:40,249 INFO     [task.py:423] Building contexts for mmlu_pro_health on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 862.93it/s]
2024-09-21:20:20:41,289 INFO     [task.py:423] Building contexts for mmlu_pro_history on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [00:00<00:00, 824.88it/s]
2024-09-21:20:20:41,798 INFO     [task.py:423] Building contexts for mmlu_pro_law on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1101/1101 [00:01<00:00, 855.97it/s]
2024-09-21:20:20:43,211 INFO     [task.py:423] Building contexts for mmlu_pro_math on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1351/1351 [00:01<00:00, 831.30it/s]
2024-09-21:20:20:44,988 INFO     [task.py:423] Building contexts for mmlu_pro_other on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 924/924 [00:01<00:00, 818.69it/s]
2024-09-21:20:20:46,235 INFO     [task.py:423] Building contexts for mmlu_pro_philosophy on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 499/499 [00:00<00:00, 824.51it/s]
2024-09-21:20:20:46,900 INFO     [task.py:423] Building contexts for mmlu_pro_physics on rank 0...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1299/1299 [00:01<00:00, 843.07it/s]
2024-09-21:20:20:48,589 INFO     [task.py:423] Building contexts for mmlu_pro_psychology on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 798/798 [00:00<00:00, 827.32it/s]
2024-09-21:20:20:49,653 INFO     [evaluator.py:465] Running generate_until requests
Running generate_until requests:   0%|                                                                                                                                                                                                                            | 0/12032 [00:00<?, ?it/s]torch.Size([1, 3121])
Traceback (most recent call last):
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/salman/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/salman/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/salman/torchtune/torchtune/_cli/run.py", line 185, in _run_cmd
    self._run_single_device(args)
  File "/home/salman/torchtune/torchtune/_cli/run.py", line 94, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/salman/torchtune/recipes/eleuther_eval.py", line 303, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/home/salman/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/home/salman/torchtune/recipes/eleuther_eval.py", line 299, in recipe_main
    recipe.evaluate()
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/torchtune/recipes/eleuther_eval.py", line 281, in evaluate
    output = evaluate(
             ^^^^^^^^^
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/lm_eval/utils.py", line 397, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/lm_eval/evaluator.py", line 476, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/lm_eval/models/huggingface.py", line 1279, in generate_until
    cont = self._model_generate(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/torchtune/recipes/eleuther_eval.py", line 147, in _model_generate
    toks, _ = generation.generate(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/torchtune/torchtune/generation/_generation.py", line 312, in generate
    tokens, generated_logits = generate_next_token(
                               ^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/torchtune/torchtune/generation/_generation.py", line 102, in generate_next_token
    logits = model(x, input_pos=input_pos, mask=mask)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/.pyenv/versions/3.11.9/envs/tune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/salman/torchtune/torchtune/modules/transformer.py", line 591, in forward
    self._validate_inputs(
  File "/home/salman/torchtune/torchtune/modules/transformer.py", line 498, in _validate_inputs
    raise ValueError(
ValueError: seq_len (3121) of input tensor should be smaller than max_seq_len (2048)

cc @joecummings

@SalmanMohammadi
Collaborator Author

We should address this by truncating the context ourselves in the generate call.
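
For illustration, a minimal sketch of that truncation inside the recipe's _model_generate, assuming the wrapper exposes max_gen_toks and the model exposes max_seq_len (neither attribute name is confirmed here; this is not the final fix):

# Minimal sketch: left-truncate the prompt so that prompt length plus the
# tokens we're about to generate stays within the model's max_seq_len.
def _model_generate(self, context, **generation_kwargs):
    max_ctx_len = self._model.max_seq_len - self.max_gen_toks  # leave room to generate
    if context.size(-1) > max_ctx_len:
        context = context[:, -max_ctx_len:]  # keep the rightmost (most recent) tokens
    toks, _ = generation.generate(self._model, context, **generation_kwargs)
    return toks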

@joecummings
Contributor

@SalmanMohammadi What's the update here? Eleuther is going to push a fix?

@SalmanMohammadi
Collaborator Author

Yep yep. https://github.com/EleutherAI/lm-evaluation-harness/pull/2353/files
I'll patch it in and verify it works

@joecummings
Contributor

> Yep yep. EleutherAI/lm-evaluation-harness#2353 (files) I'll patch it in and verify it works

Not sure I follow - how does this deal with max_seq_len?

@SalmanMohammadi
Collaborator Author

SalmanMohammadi commented Sep 25, 2024

After some investigation, I think we'll need the above PR to land, since it addresses these lines in lm_eval's generate_until:

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

These lines determine the appropriate length to truncate the prompt to, given the configured self.max_length. Since our model is neither of the above classes, we never hit the branch that sets max_ctx_len. cc @baberabb to confirm
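
To make the gap concrete, here's a hedged one-liner for the branch our wrapper would need; self.max_length and self.max_gen_toks follow lm_eval's HFLM attribute naming and should be treated as assumptions:

# our decoder-only torchtune model plays the causal-LM role here, so the input
# budget is max_length (i.e. the config's max_seq_length) minus room to generate
max_ctx_len = self.max_length - self.max_gen_toks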

The actual truncation happens in self.tok_batch_encode and self.tok_encode, which we need to implement ourselves, e.g.

# in lm_eval
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )

# in torchtune
# (sketch; assumes `from typing import List, Optional, Tuple`, `import torch`,
#  and torchtune's left_pad_sequence utility are in scope)
    def tok_batch_encode(
        self, text: List[str], left_truncate_len: Optional[int] = None, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tokenized_text = [self.tok_encode(x) for x in text]

        # pad left so every prompt in the batch has the same length
        x = left_pad_sequence(
            [torch.tensor(x) for x in tokenized_text],
            batch_first=True,
            padding_value=self._tokenizer.pad_id,
        )
        # left-truncate so the prompt leaves room for the tokens to be generated
        if left_truncate_len is not None:
            x = x[:, -left_truncate_len:]

        return x, torch.ones_like(x)  # return 'mask' b/c it's expected by the harness
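
As a quick sanity check of the intended behaviour (eval_wrapper and long_prompt are hypothetical names; 2048 and 256 stand in for the model's max_seq_len and a typical max_gen_toks):

max_ctx_len = 2048 - 256  # model max_seq_len minus room for generated tokens
context_enc, attn_mask = eval_wrapper.tok_batch_encode(
    [long_prompt], left_truncate_len=max_ctx_len
)
# the 3121-token prompt from the log above would now be left-truncated
assert context_enc.shape[-1] <= max_ctx_len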
