-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
1 parent
bffee7c
commit a2242e7
Showing
12 changed files
with
206 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import os | ||
import argparse | ||
import torch | ||
import re | ||
|
||
from tutel.system import apply_rank_size_from_pattern | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--input_size', type=int, required=True) | ||
parser.add_argument('--inputs', type=str, required=True) | ||
parser.add_argument('--output', type=str, required=True) | ||
args = parser.parse_args() | ||
args.size = args.input_size | ||
|
||
mutate_size, expert_dict = {}, {} | ||
|
||
input_file = apply_rank_size_from_pattern(args.inputs, rank=0, size=args.size) | ||
state_dict = torch.load(input_file, map_location=torch.device('cpu')) | ||
for k in state_dict: | ||
if k.endswith('._num_global_experts'): | ||
entry = k[:k.rindex('.')] + '.experts.' | ||
mutate_size[entry] = int(state_dict[k]) | ||
|
||
if not mutate_size: | ||
raise Exception('No any Tutel MoE layer is found, as the provided checkpoint may be in legacy format. You need to reload this legacy checkpoint by corresponding application, re-checkpoint model\'s state_dict and get the latest format.') | ||
|
||
for rank in range(args.size): | ||
input_file = apply_rank_size_from_pattern(args.inputs, rank=rank, size=args.size) | ||
state_dict = torch.load(input_file, map_location=torch.device('cpu')) | ||
for k in state_dict: | ||
for e in mutate_size: | ||
if k.startswith(e): | ||
expert_dict[k] = expert_dict.get(k, [mutate_size[e],]) + [state_dict[k],] | ||
|
||
expert_dict = [(i, k, expert_dict[k]) for i, k in enumerate(expert_dict)] | ||
for i, k, v in expert_dict: | ||
num_global_experts, pieces = v[0], v[1:] | ||
if num_global_experts % args.size == 0: | ||
expert_dict[i] = torch.concat(pieces, dim=0).contiguous().clone() | ||
assert expert_dict[i].size(0) == num_global_experts, "Unexpected group size of expert with num_global_experts: %d v.s. %d. Maybe you set a wrong --size value." % (expert_dict[i].size(0), num_global_experts) | ||
elif args.size % num_global_experts == 0: | ||
expert_dict[i] = torch.concat(pieces, dim=0).contiguous() | ||
expert_dict[i] = expert_dict[i].view([num_global_experts, -1] + list(expert_dict[i].shape)[2:]).clone() | ||
else: | ||
raise Exception(f'Neither of "global_experts({num_global_experts}) / args.size({args.size})" nor "args.size({args.size}) / global_experts({num_global_experts})" is evenly divisible.') | ||
expert_dict[i] = (k, expert_dict[i]) | ||
|
||
expert_dict = dict(expert_dict) | ||
for k in state_dict: | ||
if k in expert_dict: | ||
state_dict[k] = expert_dict[k] | ||
torch.save(state_dict, args.output) | ||
|
||
if __name__ == "__main__": | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import os | ||
import argparse | ||
import torch | ||
import re | ||
|
||
from tutel.system import apply_rank_size_from_pattern | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--output_size', type=int, required=True) | ||
parser.add_argument('--input', type=str, required=True) | ||
parser.add_argument('--outputs', type=str, required=True) | ||
args = parser.parse_args() | ||
args.size = args.output_size | ||
|
||
state_dict = torch.load(args.input, map_location=torch.device('cpu')) | ||
mutate_size, expert_dict = {}, {} | ||
|
||
for k in state_dict: | ||
if k.endswith('._num_global_experts'): | ||
entry = k[:k.rindex('.')] + '.experts.' | ||
mutate_size[entry] = int(state_dict[k]) | ||
|
||
if not mutate_size: | ||
raise Exception('No any Tutel MoE layer is found, as the provided checkpoint may be in legacy format. You need to reload this legacy checkpoint by corresponding application, re-checkpoint model\'s state_dict and get the latest format.') | ||
|
||
for k in state_dict: | ||
for e in mutate_size: | ||
if k.startswith(e): | ||
state = state_dict[k] | ||
shape = state.shape | ||
if shape[0] % args.size == 0: | ||
state = state.view([args.size, shape[0] // args.size] + list(shape)[1:]) | ||
elif args.size % shape[0] == 0: | ||
divisor = args.size // shape[0] | ||
for i in range(1, len(shape)): | ||
if shape[i] <= 1: | ||
continue | ||
assert shape[i] % divisor == 0, f"The second non-squeezable dimension is to be sliced to {divisor} pieces from an parameter of shape {shape}, which isn't divisible evenly." | ||
state = state.view([args.size,] + list(shape)[1:i] + [shape[i] // divisor,] + list(shape)[i+1:]) | ||
else: | ||
raise Exception(f'Neither of "global_experts({int(shape[0])}) / args.size({args.size})" nor "args.size({args.size}) / global_experts({int(shape[0])})" is evenly divisible.') | ||
expert_dict[k] = state | ||
|
||
for rank in range(args.size): | ||
generate_dict = dict() | ||
for k in state_dict: | ||
if k not in expert_dict: | ||
generate_dict[k] = state_dict[k] | ||
else: | ||
generate_dict[k] = expert_dict[k][rank, :].contiguous().clone() | ||
|
||
output_file = apply_rank_size_from_pattern(args.outputs, rank=rank, size=args.size) | ||
torch.save(generate_dict, output_file) | ||
|
||
if __name__ == "__main__": | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.