Ademamix mixing #11

Open
edmondja opened this issue Nov 19, 2024 · 10 comments

Comments

@edmondja

FYI, mixing with AdEMAMix gives great performance:
https://github.com/edmondja/AdEMAMix-ADOPT-Optimizer-Pytorch/blob/main/AdEMAMix-ADOPT.py

@And233

And233 commented Nov 23, 2024

FYI, mixing with AdEMAMix gives great performance: https://github.com/edmondja/AdEMAMix-ADOPT-Optimizer-Pytorch/blob/main/AdEMAMix-ADOPT.py

Does it need more VRAM? I don't know why, but it gets stuck after 2 steps. I can use the vanilla AdEMAMix smoothly with my 16 GB card for Stable Diffusion LoRA training.

@edmondja
Author

Ah, interesting. It shouldn't, since we only divide by a value that already exists and we aren't making a deepcopy of it.

@And233

And233 commented Nov 23, 2024

Ah, interesting. It shouldn't, since we only divide by a value that already exists and we aren't making a deepcopy of it.

Oh, maybe it's the precision? Could you make an 8-bit version? It seems this one is float32. bnb has an 8-bit implementation of the vanilla optimizer: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/bitsandbytes/optim/ademamix.py
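
For reference, a minimal usage sketch of the linked bitsandbytes 8-bit optimizer. The class name AdEMAMix8bit and its availability under bnb.optim are assumptions based on the linked file, not something confirmed in this thread:

    import torch
    import bitsandbytes as bnb

    # Assumed class name from the linked ademamix.py; check your bnb release.
    model = torch.nn.Linear(64, 64).cuda()  # bnb 8-bit optimizers expect CUDA params
    optimizer = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-4)

    loss = model(torch.randn(8, 64, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()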

@gesen2egee

gesen2egee commented Nov 24, 2024

I modified line 88 of bnb's ademamix.py, under the # Update the EMAs section.
It runs, and the logs look okay. Let's see if the results improve.

\Lib\site-packages\bitsandbytes\optim\ademamix.py

            # Update the EMAs
            if torch.min(nu) == 0:
                sq_norm = torch.ones_like(grad)
            else:
                sq_norm = torch.maximum(torch.sqrt(nu), torch.tensor(eps * 100))

            m1.mul_(beta1).addcdiv_(grad, sq_norm, alpha=1 - beta1)
            m2.mul_(beta3).addcdiv_(grad, sq_norm, alpha=1 - beta3)
            nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            # Compute step

@gesen2egee

I also noticed that ADOPT has been updated with add_clip to help stabilize the early stages.
How can the same operation be implemented in this version?
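
For context, here is a standalone sketch of that ADOPT-style clipping, the same pattern used in the snippets below (toy tensors stand in for the optimizer state; in the real optimizer grad and nu come from p.grad and the second-moment EMA):

    import torch

    # Toy stand-ins for optimizer state.
    grad = torch.randn(10)
    nu = torch.rand(10)
    eps, step = 1e-6, 1

    denom = torch.clamp(nu.sqrt(), eps)   # guard against division by ~0
    norm_grad = grad / denom
    clip = step ** 0.25                   # clip bound grows with the step count
    norm_grad.clamp_(-clip, clip)         # the earliest steps are damped the most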

@gesen2egee

gesen2egee commented Nov 27, 2024

            # Update the EMAs
            if torch.min(nu) == 0:
                sq_norm = torch.ones_like(grad)
            else:
                sq_norm = torch.maximum(torch.sqrt(nu), torch.tensor(eps * 100))

            m1.mul_(beta1).addcdiv_(grad, sq_norm, alpha=1 - beta1)
            m2.mul_(beta3).addcdiv_(grad, sq_norm, alpha=1 - beta3)
            nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            # Compute step
            exp_avg = (m1.div(bias_correction1) + alpha * m2) 
            denom = (nu.sqrt() / (bias_correction2**0.5)).add(eps)

            # Compute norm gradient
            mask = (exp_avg * grad > 0).to(grad.dtype)
            mask = mask * (mask.numel() / (mask.sum() + 1))
            norm_grad = (exp_avg * mask) / denom

            clip = group["step"] ** 0.25
            norm_grad.clamp_(-clip, clip)
            
            # Use stableadamw
            rms_min = 1.0
            rms = torch.div(
                grad.pow(2), 
                torch.maximum(nu, (eps ** 2) * torch.ones_like(nu))
            ).mean().sqrt().item()
            new_lr = lr * (1. / max(1., rms / rms_min))

            p.add_(norm_grad, alpha=-new_lr)

            # Add weight decay
            if weight_decay > 0.0:
                p.add_(p, alpha=-new_lr * weight_decay)

    return loss

Try adding this.

Cautious Optimizer (C-Optim): Improving Training with One Line of Code
https://github.com/kyleliang919/C-Optim/tree/main

And StableAdamW
https://gist.github.com/mitchellnw/d42e22a0b9ec02ceaf4f7b4457f51423

Not sure if it's correct, but it seems pretty good judging by the training results.
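
For clarity, the StableAdamW-style learning-rate clipping inlined above can also be written as a small standalone helper (a sketch of the same computation shown in the snippet, not code taken from the gist):

    import torch

    def stable_lr(lr, grad, nu, eps=1e-8, rms_min=1.0):
        # RMS of grad^2 / max(nu, eps^2); if it exceeds rms_min, scale the
        # learning rate down proportionally (the StableAdamW rule used above).
        rms = torch.div(
            grad.pow(2),
            torch.maximum(nu, (eps ** 2) * torch.ones_like(nu)),
        ).mean().sqrt().item()
        return lr * (1.0 / max(1.0, rms / rms_min))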

@And233

And233 commented Nov 28, 2024


That's great! I tried 60 epochs on my small dataset for SDXL LoRA training (also used with immiscible noise), and it seems to be better than the 150-epoch run with the old method.

@gesen2egee

gesen2egee commented Nov 28, 2024

@And233 Thank you

Another modification: add automatic warm-up like RAdam. You can try using a higher learning rate (e.g., 1e-3) combined with a decay scheduler; a usage sketch follows the code below. From my testing, it looks much better than before.

The changes also start at line 88 and run all the way to the return loss.

            # ADOPT (https://github.com/iShohei220/adopt/blob/main/adopt.py)
            step = group["step"]
            denom = torch.clamp(nu.sqrt(), eps) 
            norm_grad = grad / denom
            clip = step ** 0.25  # Define clip threshold
            norm_grad.clamp_(-clip, clip)

            # AdEMAMix (https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/bitsandbytes/optim/ademamix.py)
            m1.mul_(beta1).add_(norm_grad, alpha=1 - beta1)
            m2.mul_(beta3).add_(norm_grad, alpha=1 - beta3)
            nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            # C_AdamW (https://github.com/kyleliang919/C-Optim/tree/main)
            mask = (m1 * grad > 0).to(grad.dtype)
            mask = mask * (mask.numel() / (mask.sum() + 1))
            masked_m1 = m1 * mask
            
            update = masked_m1.div(bias_correction1) + (alpha * m2)

            # StableAdamW (https://gist.github.com/mitchellnw/d42e22a0b9ec02ceaf4f7b4457f51423)
            rms_min = 1.0
            rms = torch.div(
                grad.pow(2), 
                torch.maximum(nu, (eps ** 2) * torch.ones_like(nu))
            ).mean().sqrt().item()

            new_lr = lr * (1. / max(1., rms / rms_min))
            
            # RAdam (https://pytorch.org/docs/stable/_modules/torch/optim/radam.html#RAdam)
            rho_inf = 2 / (1 - beta2) - 1
            rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

            def _compute_rect():
                return (
                    (rho_t - 4)
                    * (rho_t - 2)
                    * rho_inf
                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                ) ** 0.5

            def _compute_adaptive_lr():
                exp_avg_sq_sqrt = nu.sqrt()
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
                return (bias_correction2**0.5) / exp_avg_sq_sqrt

            if rho_t > 5.0:
                p.add_(
                    update
                    * new_lr
                    * _compute_adaptive_lr()
                    * _compute_rect(),
                    alpha=-1.0,
                )
            else:
                p.add_(update * new_lr, alpha=-1.0)

            # Add weight decay
            if weight_decay > 0.0:
                p.add_(p, alpha=-new_lr * weight_decay)

    return loss
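
A hypothetical wiring of the suggested settings (a higher learning rate such as 1e-3 plus a decay scheduler). torch.optim.AdamW is only a stand-in here; you would swap in the modified AdEMAMix/ADOPT optimizer from the snippet above:

    import torch

    model = torch.nn.Linear(128, 128)
    # Stand-in optimizer; replace with the patched AdEMAMix/ADOPT variant.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

    for _ in range(1000):
        optimizer.zero_grad()
        loss = model(torch.randn(32, 128)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()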

@And233

And233 commented Nov 28, 2024


Do you mean trying a scheduler like CosineAnnealing? I used to set a 20% warm-up and a constant scheduler with a high lr (5e-3). Could this RAdam replace the warm-up steps?

@gesen2egee

RAdam automatically scales the update down until the variance estimate becomes reliable, so it should be able to replace warm-up.
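
To illustrate, a small sketch (using the same rho_inf/rho_t formulas as the snippet above, with an assumed beta2 = 0.999) shows how the rectification term starts near zero and ramps up over the first steps, which is what produces the warm-up effect:

    # How the RAdam rectification term ramps up over early steps.
    beta2 = 0.999
    rho_inf = 2 / (1 - beta2) - 1
    for step in (1, 10, 50, 100, 500, 1000):
        bias_correction2 = 1 - beta2 ** step
        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2
        if rho_t > 5.0:
            rect = (((rho_t - 4) * (rho_t - 2) * rho_inf)
                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t)) ** 0.5
        else:
            rect = 0.0  # un-rectified branch: take the plain momentum step
        print(step, round(rho_t, 2), round(rect, 3))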
