|
3 | 3 | # |
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | | -import random |
7 | 6 | from typing import Optional |
8 | 7 |
|
9 | 8 | import fire |
10 | 9 | import pandas as pd |
11 | 10 | import torch |
12 | 11 | from utils import do_benchmarks, get_name_to_moe_shapes_iter |
13 | 12 |
|
| 13 | +from torchao.prototype.moe_training.utils import generate_jagged_offs |
14 | 14 | from torchao.testing.training.roofline_utils import get_specs |
15 | 15 |
|
16 | 16 |
|
@@ -146,39 +146,6 @@ def do_scaled_grouped_mm(A, B): |
146 | 146 | data_df.to_csv(out_filename) |
147 | 147 |
|
148 | 148 |
|
149 | | -def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"): |
150 | | - """ |
151 | | - Generates a tensor of length E, containing random values divisible by 16, |
152 | | - from 0 to M, in sorted order, and where the final value in the tensor is always M. |
153 | | - Args: |
154 | | - E (int): The length of the tensor. |
155 | | - M (int): The maximum value in the tensor. |
156 | | - Returns: |
157 | | - torch.Tensor: A tensor of length E with the specified properties. |
158 | | - """ |
159 | | - # Ensure M is divisible by 16 |
160 | | - if M % 16 != 0: |
161 | | - raise ValueError("M must be divisible by 16") |
162 | | - |
163 | | - # Generate a list of possible values |
164 | | - possible_values = [i for i in range(0, M + 1, 16)] |
165 | | - |
166 | | - # If E is larger than the number of possible values, raise an error |
167 | | - if E > len(possible_values): |
168 | | - raise ValueError("E cannot be larger than the number of possible values") |
169 | | - |
170 | | - # Randomly select E - 1 values from the possible values (excluding M) |
171 | | - selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1)) |
172 | | - |
173 | | - # Append M to the selected values |
174 | | - selected_values = torch.cat((selected_values, torch.tensor([M]))) |
175 | | - |
176 | | - # Sort the selected values |
177 | | - selected_values, _ = torch.sort(selected_values) |
178 | | - |
179 | | - return selected_values.to(dtype).to(device) |
180 | | - |
181 | | - |
182 | 149 | def main() -> None: |
183 | 150 | fire.Fire(run) |
184 | 151 |
|
|
0 commit comments