Skip to content

Commit 706703b

Browse files
authored
Expectations test utils (#36569)
* Add expectation classes + tests * Use typing Union instead of | * Use bits to track score in properties cmp method * Add exceptions and tests + comments * Remove compute cap minor as it is not needed currently * Simplify. Remove Properties class * Add example Exceptions usage * Expectations as dict subclass * Update example Exceptions usage * Refactor. Improve type name. Document score fn. * Rename to DeviceProperties.
1 parent 179d02f commit 706703b

File tree

3 files changed

+125
-19
lines changed

3 files changed

+125
-19
lines changed

src/transformers/testing_utils.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@
3232
import threading
3333
import time
3434
import unittest
35-
from collections import defaultdict
35+
from collections import UserDict, defaultdict
3636
from collections.abc import Mapping
3737
from dataclasses import MISSING, fields
38-
from functools import wraps
38+
from functools import cache, wraps
3939
from io import StringIO
4040
from pathlib import Path
41-
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
41+
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
4242
from unittest import mock
4343
from unittest.mock import patch
4444

@@ -3042,3 +3042,78 @@ def cleanup(device: str, gc_collect=False):
30423042
if gc_collect:
30433043
gc.collect()
30443044
backend_empty_cache(device)
3045+
3046+
3047+
# Type definition of key used in `Expectations` class: `(device type, compute capability major version)`.
# Either element may be `None`; `(None, None)` is the default (catch-all) key.
DeviceProperties = tuple[Union[str, None], Union[int, None]]


@cache
def get_device_properties() -> DeviceProperties:
    """
    Get environment device properties.

    The result is memoized by `functools.cache`.

    Returns:
        `DeviceProperties`: `("rocm", major)` or `("cuda", major)` on ROCm / CUDA systems, where `major` is the
        first element of `torch.cuda.get_device_capability()`; `(torch_device, None)` otherwise.
    """
    # NOTE: this is a module-level function, so it takes no `self`; `Expectations.get_expectation` calls it
    # without arguments.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        import torch

        major, _ = torch.cuda.get_device_capability()
        if IS_ROCM_SYSTEM:
            return ("rocm", major)
        else:
            return ("cuda", major)
    else:
        return (torch_device, None)


class Expectations(UserDict[DeviceProperties, Any]):
    """
    Mapping from `DeviceProperties` keys to expected values, with best-match lookup.

    Example:

    ```python
    expectations = Expectations({(None, None): 1, ("cuda", 8): 2, ("rocm", 9): 3})
    expectations.find_expectation(("cuda", 8))  # 2
    expectations.find_expectation(("xpu", None))  # 1, falls back to the default key
    ```
    """

    def get_expectation(self) -> Any:
        """
        Find best matching expectation based on environment device properties.
        """
        return self.find_expectation(get_device_properties())

    @staticmethod
    def is_default(key: DeviceProperties) -> bool:
        """
        Returns `True` when every element of `key` is `None`, i.e. `key` is the default key.
        """
        return all(p is None for p in key)

    @staticmethod
    def score(key: DeviceProperties, other: DeviceProperties) -> int:
        """
        Returns score indicating how similar two instances of the `DeviceProperties` tuple are.
        Points are calculated using bits, but documented as int.
        Rules are as follows:
            * Matching `type` gives 8 points.
            * Semi-matching `type`, for example cuda and rocm, gives 4 points.
            * Matching `major` (compute capability major version) gives 2 points.
            * Default expectation (if present) gives 1 point.

        Args:
            key (`DeviceProperties`): Properties being looked up.
            other (`DeviceProperties`): Candidate key to compare against.

        Returns:
            `int`: The combined score; `0` means `other` does not match `key` in any way.
        """
        (device_type, major) = key
        (other_device_type, other_major) = other

        score = 0b0
        if device_type == other_device_type:
            score |= 0b1000
        elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
            score |= 0b100

        # A `None` major on the candidate never counts as a match, even against a `None` major.
        if major == other_major and other_major is not None:
            score |= 0b10

        if Expectations.is_default(other):
            score |= 0b1

        return int(score)

    def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
        """
        Find best matching expectation based on provided device properties.

        Args:
            key (`DeviceProperties`, defaults to `(None, None)`): Properties to find an expectation for.

        Returns:
            The value stored under the highest scoring key. When several keys share the highest score, the one
            inserted first wins.

        Raises:
            `ValueError`: If the mapping is empty or no stored key scores above 0.
        """
        best_score = 0
        result = None
        for candidate, value in self.data.items():
            candidate_score = Expectations.score(key, candidate)
            # Strict comparison keeps the first inserted entry on ties.
            if candidate_score > best_score:
                best_score = candidate_score
                result = value

        # Also covers an empty mapping, where the loop body never runs.
        if best_score == 0:
            raise ValueError(f"No matching expectation found for {key}")

        return result

    def __repr__(self):
        return f"{self.data}"

tests/models/bamba/test_modeling_bamba.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@
2020
import pytest
2121

2222
from transformers import AutoTokenizer, BambaConfig, is_torch_available
23-
from transformers.testing_utils import (
24-
require_torch,
25-
require_torch_gpu,
26-
slow,
27-
torch_device,
28-
)
23+
from transformers.testing_utils import Expectations, require_torch, require_torch_gpu, slow, torch_device
2924

3025
from ...generation.test_utils import GenerationTesterMixin
3126
from ...test_configuration_common import ConfigTester
@@ -503,15 +498,18 @@ def setUpClass(cls):
503498
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
504499

505500
def test_simple_generate(self):
506-
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
507-
#
508-
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
509-
# considering differences in hardware processing and potential deviations in generated text.
510-
EXPECTED_TEXTS = {
511-
# 7: "",
512-
8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
513-
9: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
514-
}
501+
expectations = Expectations(
502+
{
503+
(
504+
"cuda",
505+
8,
506+
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
507+
(
508+
"rocm",
509+
9,
510+
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
511+
}
512+
)
515513

516514
self.model.to(torch_device)
517515

@@ -520,7 +518,8 @@ def test_simple_generate(self):
520518
].to(torch_device)
521519
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
522520
output_sentence = self.tokenizer.decode(out[0, :])
523-
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
521+
expected = expectations.get_expectation()
522+
self.assertEqual(output_sentence, expected)
524523

525524
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
526525
if self.cuda_compute_capability_major_version == 8:

tests/utils/test_expectations.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import unittest
2+
3+
from transformers.testing_utils import Expectations
4+
5+
6+
class ExpectationsTest(unittest.TestCase):
    """Checks which stored value `Expectations.find_expectation` selects for a given device key."""

    def test_expectations(self):
        expectations = Expectations(
            {
                (None, None): 1,
                ("cuda", 8): 2,
                ("cuda", 7): 3,
                ("rocm", 8): 4,
                ("rocm", None): 5,
                ("cpu", None): 6,
            }
        )

        # Each case is (expected value, device properties to look up).
        cases = [
            # xpu has no matches so should find default expectation
            (1, ("xpu", None)),
            (2, ("cuda", 8)),
            (3, ("cuda", 7)),
            # no rocm entry has major 9, so the first rocm entry is expected
            (4, ("rocm", 9)),
            (4, ("rocm", None)),
            # no cuda entry has major 2, so the first cuda entry is expected
            (2, ("cuda", 2)),
        ]
        for value, key in cases:
            # `subTest` reports every failing key instead of stopping at the first one.
            with self.subTest(key=key):
                self.assertEqual(expectations.find_expectation(key), value)

        # Without a default entry, an unrelated device type must raise.
        expectations = Expectations({("cuda", 8): 1})
        with self.assertRaises(ValueError):
            expectations.find_expectation(("xpu", None))

0 commit comments

Comments
 (0)