This example shows how to use Tutel MoE with torch autocast (AMP).
Q: Is the All2All still meant to be done in FP32?
In general, torch autocast (AMP) keeps the network's master weights in FP32 and downcasts them before a layer's forward pass.
So within an autocast context (as suggested here), the matmul here will be autocast to fp16.
(Note that torch.add is done in fp32; see the list of ops which are autocast. By default, ops promote to the input with the highest precision; since batched_fc1_bias is fp32, the add will run in fp32 and produce an fp32 output.)
So far everything is just standard torch.
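To make that concrete, here is a minimal standalone sketch of plain PyTorch behavior (not Tutel code; the shapes are made up and a CUDA device is assumed; only batched_fc1_bias is a name taken from the example): the matmul is autocast to fp16, while adding the fp32 bias promotes the result back to fp32.

import torch

x = torch.randn(8, 16, device="cuda")                  # fp32 activations
w = torch.randn(16, 32, device="cuda")                  # fp32 master weight
batched_fc1_bias = torch.randn(32, device="cuda")       # fp32 bias, as in the example

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = torch.matmul(x, w)        # autocast casts both operands to fp16 -> y is fp16
    z = y + batched_fc1_bias      # type promotion picks the widest dtype -> z is fp32

print(y.dtype, z.dtype)           # torch.float16 torch.float32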
My question is about these few lines of code. Since expert weights are in fp32, this will upcast input x to type fp32.
As a result the All2All communication is done using fp32 inputs.
Is this correct or am I missing some other cast?
(Note: in the cast at the end of this line, x has already been cast to fp32.)
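As a simplified illustration of the cast in question (a sketch only, not the actual Tutel source; expert_fc1 is a hypothetical stand-in for the expert parameters): keeping the expert weights in fp32 and matching x to their dtype silently undoes the fp16 downcast, so the tensor handed to the dispatch/all2all step is fp32.

import torch
import torch.nn as nn

expert_fc1 = nn.Linear(16, 32)                      # expert weights kept as fp32 master copies
x = torch.randn(8, 16, dtype=torch.float16)         # fp16 activations from the previous layer

# The cast being discussed: matching x to the expert parameter dtype upcasts it,
# so the tensor that enters the all-to-all exchange is fp32.
x = x.to(next(expert_fc1.parameters()).dtype)
print(x.dtype)                                      # torch.float32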
It looks like the all2all is ALWAYS done using fp32 precision even if we are using an AMP autocast context manager. Was this done deliberately, or is it a bug? It seems like if the all2all were done using 16 bits we'd save 2x the bandwidth.
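For reference, a 16-bit exchange could look something like the sketch below (my own sketch, not existing Tutel code; all_to_all_fp16 is a hypothetical helper and it assumes torch.distributed is already initialized with an even split across ranks): cast to fp16 for the wire, exchange, then cast back, halving the bytes moved.

import torch
import torch.distributed as dist

def all_to_all_fp16(x: torch.Tensor) -> torch.Tensor:
    wire = x.to(torch.float16)            # 2 bytes/element on the wire instead of 4
    out = torch.empty_like(wire)
    dist.all_to_all_single(out, wire)     # the communication itself happens in fp16
    return out.to(x.dtype)                # restore the original dtype for the expert compute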
Final note: as mentioned, inside the autocast context manager, even though the all2all is done in fp32, autocast is still on and therefore the matmuls here are done in fp16.
Potential bug: I'm not sure this does anything... That layer should already be in fp32, and when it's run here, autocast should still run it in fp16...
I THINK the right way to do this is something like this:
def forward(self, x):
    if self.fp32_gate:
        # Run the gate in full precision: upcast the input to fp32.
        x = x.float()
    # When fp32_gate is set, disable autocast so wg really runs in fp32
    # even inside an outer autocast region; otherwise wg runs under autocast.
    with torch.autocast(device_type=x.device.type, enabled=not self.fp32_gate):
        out = self.wg(x)
    return out
Autocast is disabled here.
I'm pretty sure the suggested rewrite for gate autocast is correct and more understandable.
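To sanity-check the suggested forward, here is a self-contained toy version (a hypothetical TinyGate module wrapping the snippet above, not Tutel's actual gate class): with fp32_gate=True the gate logits come out fp32 even when the gate is called from inside an autocast region.

import torch
import torch.nn as nn

class TinyGate(nn.Module):
    def __init__(self, model_dim, num_experts, fp32_gate=True):
        super().__init__()
        self.fp32_gate = fp32_gate
        self.wg = nn.Linear(model_dim, num_experts, bias=False)   # fp32 master weights

    def forward(self, x):
        if self.fp32_gate:
            x = x.float()
        # Disable autocast when fp32_gate is set, so wg really runs in fp32
        # even inside an outer autocast region.
        with torch.autocast(device_type=x.device.type, enabled=not self.fp32_gate):
            return self.wg(x)

gate = TinyGate(model_dim=16, num_experts=4)
x = torch.randn(8, 16, dtype=torch.float16)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = gate(x)
print(logits.dtype)    # torch.float32: the gate matmul ran outside autocast, in fp32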
In general it seems as though, if the input is not explicitly cast here (i.e., we comment out those lines), then: if the input to the MoE layer is fp16, the all2all input is fp16 (and I'm assuming the all2all will be done in fp16); if the input to the MoE layer is fp32, the input to the all2all is fp32 (and I'm assuming the all2all will be done in fp32).