Skip to content

Commit fd4d9ef

Browse files
authored
Wave scaling patch (#35)
* quick hotfix for transformer models * added logging * changed spacing * fixed division by 0
1 parent e2dbb36 commit fd4d9ef

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

analyzer/habitat/analysis/wave_scaling/resimplified.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from habitat.analysis.kernels import PredictedKernel
44
from habitat.analysis.wave_scaling.common import calculate_wave_info
5-
5+
import logging
6+
logger = logging.getLogger(__name__)
67

78
def resimplified_wave_scaling(
89
kernel,
@@ -23,8 +24,9 @@ def resimplified_wave_scaling(
2324
# Check if the kernel is too "small" - if it doesn't fill a single wave
2425
# on the current device AND if it doesn't fill a single wave on the
2526
# destination device
26-
if (kernel.num_blocks // origin_wave_size == 0 and
27-
kernel.num_blocks // dest_wave_size == 0):
27+
if (origin_wave_size == 0 or dest_wave_size == 0):
28+
logger.warn(f"One or more invalid wave sizes: kernel: {kernel.name} origin: {origin_wave_size}, dest: {dest_wave_size}")
29+
if ((origin_wave_size == 0 or dest_wave_size == 0) or (kernel.num_blocks // origin_wave_size == 0 and kernel.num_blocks // dest_wave_size == 0)):
2830
# We scale the run time by the compute factor only
2931
origin_max_occupancy = math.ceil(
3032
kernel.num_blocks / origin_device.num_sms

analyzer/habitat/analysis/wave_scaling/roofline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def roofline_wave_scaling(
2626
# 1. Check if the kernel is too "small" - if it doesn't fill a single wave
2727
# on the current device AND if it doesn't fill a single wave on the
2828
# destination device
29-
if (kernel.num_blocks // origin_wave_size == 0 and
30-
kernel.num_blocks // dest_wave_size == 0):
29+
if ((origin_wave_size == 0 or dest_wave_size == 0) or (kernel.num_blocks // origin_wave_size == 0 and
30+
kernel.num_blocks // dest_wave_size == 0)):
3131
# We scale the run time by the compute factor only
3232
origin_max_occupancy = math.ceil(
3333
kernel.num_blocks / origin_device.num_sms

0 commit comments

Comments
 (0)