Skip to content

Commit 03cde49

Browse files
chuanqi129facebook-github-bot
authored andcommitted
Align xpu models batch size with A100 (#2378)
Summary: To align xpu batchsize for dynamobenchmark torchbench suite Pull Request resolved: #2378 Reviewed By: aaronenyeshi Differential Revision: D59961717 Pulled By: xuzhao9 fbshipit-source-id: c926d6d14d8b979284aa465738132c17425dafe9
1 parent 11cf319 commit 03cde49

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

torchbenchmark/util/model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232
from torchbenchmark.util.input import input_cast, ModelInputDescriptor
3333

34-
SPECIAL_DEVICE_MAPPING = {"AMD Instinct MI210": "NVIDIA A100-SXM4-40GB"}
34+
SPECIAL_DEVICE_MAPPING = {"AMD Instinct MI210": "NVIDIA A100-SXM4-40GB", "Intel(R) Data Center GPU Max 1100": "NVIDIA A100-SXM4-40GB", "Intel(R) Data Center GPU Max 1550": "NVIDIA A100-SXM4-40GB"}
3535

3636

3737
class PostInitProcessor(type):
@@ -211,16 +211,24 @@ def _determine_dynamic_num_batches(
211211
return 1
212212

213213
def _get_batch_size_from_metadata(self) -> Optional[str]:
214-
if self.device != "cuda":
215-
current_device_name = str(self.device)
216-
else:
214+
if self.device == "cuda":
217215
current_device_name = (
218216
torch.cuda.get_device_name()
219217
if torch.cuda.get_device_name()
220218
else "UNKNOWN"
221219
)
222220
if current_device_name in SPECIAL_DEVICE_MAPPING:
223221
current_device_name = SPECIAL_DEVICE_MAPPING[current_device_name]
222+
elif self.device == "xpu":
223+
current_device_name = (
224+
torch.xpu.get_device_name()
225+
if torch.xpu.get_device_name()
226+
else "UNKNOWN"
227+
)
228+
if current_device_name in SPECIAL_DEVICE_MAPPING:
229+
current_device_name = SPECIAL_DEVICE_MAPPING[current_device_name]
230+
else:
231+
current_device_name = str(self.device)
224232

225233
# use the device suggestion on CUDA inference tests, key should be either eval_batch_size or train_batch_size
226234
device_batch_size_key = f"{self.test}_batch_size"

0 commit comments

Comments
 (0)