From 6a80d6f99f28a15192e0c67de3ad0db365d205dd Mon Sep 17 00:00:00 2001 From: Han Guo Date: Tue, 10 Dec 2024 22:28:20 -0500 Subject: [PATCH] minor changes --- .github/workflows/wheels.yaml | 2 ++ flute/tune.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.github/workflows/wheels.yaml b/.github/workflows/wheels.yaml index 2805fdc..d993ed1 100644 --- a/.github/workflows/wheels.yaml +++ b/.github/workflows/wheels.yaml @@ -23,6 +23,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + submodules: recursive - name: Set up Linux Env if: ${{ runner.os == 'Linux' }} diff --git a/flute/tune.py b/flute/tune.py index bde63ec..38715c9 100644 --- a/flute/tune.py +++ b/flute/tune.py @@ -377,6 +377,7 @@ def tune_and_pack( group_size: int, num_seeds: int = 3, check_correctness: bool = True, + check_num_seeds: int = 3, ) -> Tuple[torch.Tensor, TuneMetaData]: if inputs.ndim != 2: raise ValueError @@ -422,12 +423,14 @@ def tune_and_pack( if check_correctness is True: for uniform in [True, False]: for identity in [True, False]: - check( - weight=weight, - weight_packed=weight_packed, - metadata=metadata, - uniform=uniform, - identity=identity) + for seed in range(check_num_seeds): + torch.manual_seed(seed) + check( + weight=weight, + weight_packed=weight_packed, + metadata=metadata, + uniform=uniform, + identity=identity) return weight_packed, metadata