Fix attn mask ignore logic in training-time trace (huggingface#32613)
* fix attn mask logic for training-time trace

* add test

* fix

* fix

* fix

* fix

* fix

* format

* [run-slow] llama

* avoid accelerate

* [run-slow] llama
zhenglongjiepheonix authored and NielsRogge committed Oct 21, 2024
1 parent 25cb920 commit 8775b8c
Showing 7 changed files with 55 additions and 5 deletions.
src/transformers/cache_utils.py (1 change: 0 additions & 1 deletion)
@@ -1284,7 +1284,6 @@ def __init__(
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
-        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "

src/transformers/modeling_attn_mask_utils.py (2 changes: 1 addition & 1 deletion)
@@ -281,7 +281,7 @@ def _ignore_causal_mask_sdpa(
         elif sliding_window is None or key_value_length < sliding_window:
             if len(attention_mask.shape) == 4:
                 return False
-            elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
+            elif not is_tracing and torch.all(attention_mask == 1):
                 if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True

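For context, a minimal sketch (not the library's `_ignore_causal_mask_sdpa` itself, and reduced to the one branch touched above) of why the all-ones shortcut has to be skipped whenever the model is being traced, including a training-time trace; the helper name below is illustrative only:

import torch

def may_ignore_causal_mask(attention_mask: torch.Tensor, is_tracing: bool) -> bool:
    # Under torch.compile / FX / JIT tracing, `torch.all(attention_mask == 1)` is evaluated
    # on the example inputs and the result is baked into the captured graph. If the example
    # batch happens to be unpadded, the graph would permanently drop the mask and then be
    # wrong for padded batches (the case the new training test below exercises).
    # The shortcut is therefore only taken in eager mode.
    if is_tracing:
        return False
    return bool(torch.all(attention_mask == 1))

padded = torch.tensor([[1, 1, 1, 0, 0]])
unpadded = torch.ones(1, 5, dtype=torch.long)
assert may_ignore_causal_mask(unpadded, is_tracing=False)      # eager + unpadded: safe to drop the mask
assert not may_ignore_causal_mask(unpadded, is_tracing=True)   # tracing: keep the mask in the graph
assert not may_ignore_causal_mask(padded, is_tracing=False)    # padded batch: always keep the mask
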
tests/models/gemma/test_modeling_gemma.py (5 changes: 4 additions & 1 deletion)
@@ -321,6 +321,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # used in `test_torch_compile`
     _torch_compile_test_ckpt = "google/gemma-2b"
 
+    # used in `test_torch_compile_for_training`
+    _torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -808,7 +811,7 @@ def test_compile_static_cache(self):
 
         prompts = ["Hello I am doing", "Hi today"]
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
-        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
+        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map=torch_device, torch_dtype=torch.float16)
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
 
         # Dynamic Cache

tests/models/llama/test_modeling_llama.py (5 changes: 4 additions & 1 deletion)
@@ -319,6 +319,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # used in `test_torch_compile`
     _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
 
+    # used in `test_torch_compile_for_training`
+    _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
+
     def setUp(self):
         self.model_tester = LlamaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
@@ -874,7 +877,7 @@ def test_compile_static_cache(self):
         ]
         tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
         model = LlamaForCausalLM.from_pretrained(
-            "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
+            "meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
 

tests/models/mistral/test_modeling_mistral.py (2 changes: 1 addition & 1 deletion)
@@ -677,7 +677,7 @@ def test_compile_static_cache(self):
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
         tokenizer.pad_token = tokenizer.eos_token
         model = MistralForCausalLM.from_pretrained(
-            "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
+            "mistralai/Mistral-7B-v0.1", device_map=torch_device, torch_dtype=torch.float16
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
 

tests/models/nemotron/test_modeling_nemotron.py (2 changes: 2 additions & 0 deletions)
@@ -94,6 +94,8 @@ class NemotronModelTest(GemmaModelTest):
 
     # used in `test_torch_compile`
     _torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
+    # used in `test_torch_compile_for_training`
+    _torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
 
     def setUp(self):
         self.model_tester = NemotronModelTester(self)

tests/test_modeling_common.py (43 changes: 43 additions & 0 deletions)
@@ -4937,6 +4937,49 @@ def test_torch_compile(self):
         for i in range(n_iter):
             _ = model.generate(**input_ids, do_sample=False)
 
+    @slow
+    @require_torch_gpu
+    def test_torch_compile_for_training(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_train_cls"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
+
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        cls = self._torch_compile_train_cls
+        model = cls(config).to(torch_device)
+
+        inputs = {
+            "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+            "attention_mask": torch.tensor(
+                [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
+                dtype=torch.int64,
+                device=torch_device,
+            ),
+            "position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
+            "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+        }
+
+        # eager backward
+        set_seed(42)
+        loss = model(**inputs).loss
+        loss.backward()
+
+        params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
+        model.zero_grad()
+        del loss
+
+        model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+        # forward compilation
+        set_seed(42)
+        loss = model(**inputs).loss
+        # backward compilation
+        loss.backward()
+        # check grad matches
+        for name, param in model._orig_mod.named_parameters():
+            torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
+
     @slow
     @require_torch_gpu  # Testing cuda graphs.
     @require_read_token

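A side note on the gradient check above: `torch.compile` wraps an `nn.Module` in an `OptimizedModule` whose `_orig_mod` attribute is the original, uncompiled module, and the parameters are shared between the two. A minimal, self-contained illustration (not part of this commit):

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
compiled = torch.compile(m)
assert compiled._orig_mod is m  # the wrapper keeps a reference to the original module

# Gradients produced through the compiled module land on the shared parameters,
# which is why the test can iterate over `model._orig_mod.named_parameters()`
# and compare them against the eager-mode gradients saved earlier.
compiled(torch.randn(2, 4)).sum().backward()
assert m.weight.grad is not None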
