Description
This change introduces a new way for AutoTP to handle sharding when `split_shape` isn't divisible by `num_kv_heads`. This style of sharding is applied in `mlp`, `lm_head`, and `embed_out` as well.
In the new approach the total size is split into chunks of 64, and the remainder always occupies a full 64-wide block, which causes extra allocations.
In addition, these extra allocations are always assigned to devices in rank order, which means devices with lower ranks always receive the extra blocks; their memory fills up first and allocation eventually fails.
We saw such failures with Llama-2-7B and Llama-2-70B, where many of these extra allocations happen in the MLP layers and ended up causing allocation failures.
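To make the imbalance concrete, here is a minimal standalone sketch of the 64-grain split described above (it mirrors the `else` branch of `get_shard_size`, with `rank` passed in explicitly so there is no `torch.distributed` dependency; the sizes below are for Llama-2's 11008-wide MLP intermediate dimension, which I'm using purely as an illustration):

```python
def shard_size(total_size, mp_size, rank):
    """Sketch of the 64-grain sharding: split total_size into 64-wide
    chunks and give the remainder chunks to the lowest ranks."""
    if total_size >= 64:
        grain_size = total_size // 64  # number of 64-wide chunks
        # remainder chunks (grain_size % mp_size) always land on ranks 0, 1, ...
        return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
    return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

# 11008 // 64 = 172 chunks; 172 = 3 * 57 + 1, so rank 0 gets one extra chunk.
sizes = [shard_size(11008, 3, r) for r in range(3)]
print(sizes)  # [3712, 3648, 3648] -> rank 0 carries 64 extra columns per layer
```

Multiplied across every MLP projection in a 70B model, that per-layer extra always lands on the same low ranks, which is exactly the skew we observed.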
I am not sure what the real motivation for this change is, or why this number (64) was chosen. Were there any experiments, with results and conclusions, behind picking it?
It would also be better if this number were configurable (via the JSON config or an environment variable).
The real issue is finding a better way to distribute those extra allocations evenly, because in the current implementation they all go to the lowest-ranked devices (0, 1, 2, and so on).
```python
def get_shard_size(total_size, mp_size, name=None, rank=None):
    global num_kv_heads
    last_linear = ["lm_head", "embed_out"]
    # When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
    if rank == None:
        rank = dist.get_rank()
    if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
            name) not in last_linear:
        my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
        return total_size * my_slices // num_kv_heads
    else:
        if total_size >= 64:
            grain_size = total_size // 64
            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
        else:
            return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
```
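For comparison, the `num_kv_heads` branch above distributes whole heads rather than 64-wide chunks, and it has the same low-rank bias for the remainder heads. A self-contained sketch of just that branch (the 8-head / 1024-wide numbers are illustrative, not from the PR):

```python
def kv_shard_size(total_size, num_kv_heads, mp_size, rank):
    """Sketch of the kv-head branch: each rank gets whole kv heads,
    with the remainder heads again going to the lowest ranks."""
    my_slices = num_kv_heads // mp_size + (1 if rank < (num_kv_heads % mp_size) else 0)
    return total_size * my_slices // num_kv_heads

# 8 kv heads over a 1024-wide projection, 3 ranks: 8 = 3*2 + 2,
# so ranks 0 and 1 get 3 heads each and rank 2 gets 2.
print([kv_shard_size(1024, 8, 3, r) for r in range(3)])  # [384, 384, 256]
```

So both code paths share the same distribution policy; any fix to spread the extras more evenly should probably cover both.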