Masking certain parameters for weight decay in adamw #1007

Closed

AakashKumarNain opened this issue Jul 13, 2024 · 10 comments

@AakashKumarNain
AakashKumarNain commented Jul 13, 2024

I have a model built in Equinox, and I want to filter the parameters so that weight decay is applied only to a certain subset of the original PyTree. But it seems that optax has a problem with PyTrees passed as a mask. Here is a MWE:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax

class MLP(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=key1, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)

    def __call__(self, x):
       pass


class Attention(eqx.Module):
    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=key1) # 3 for qkv
        self.proj = eqx.nn.Linear(64, 64, key=key2)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        pass


class TransformerBlock(eqx.Module):
    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=key1, dtype=dtype)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.mlp = MLP(key=key2, dtype=dtype)

    def __call__(self, x, mask=None):
        pass


class Transformer(eqx.Module):
    pos_embed: eqx.nn.Embedding
    tf_blocks: list
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        keys = jax.random.split(key, num_layers + 3)
        key1, key2, key3, tf_keys = keys[0], keys[1], keys[2], keys[3:]

        self.tf_blocks = [TransformerBlock(tf_keys[i]) for i in range(num_layers)]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=key1)

    def __call__(self, x, y, mask=None):
        pos_embed = jax.vmap(self.pos_embed)(y)




def is_layer(x):
    return isinstance(x, eqx.nn.Linear) or isinstance(x, eqx.nn.LayerNorm)

def set_mask(x):
    if isinstance(x, eqx.nn.Linear):
        return jtu.tree_map(lambda _: True, x)
    elif isinstance(x, eqx.nn.LayerNorm):
        mask = jtu.tree_map(lambda _: False, x)
        mask = eqx.tree_at(lambda m: m.bias, mask, True)
        return mask
    else:
        return jtu.tree_map(lambda _: False, x)


model = Transformer(jax.random.PRNGKey(1))
params = eqx.filter(model, eqx.is_array)
mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)
optim = optax.adamw(learning_rate=1e-4, mask=mask)
opt_state = optim.init(params)

Traceback

---> 83 opt_state = optim.init(params)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in chain.<locals>.init_fn(params)
     63 def init_fn(params):
---> 64   return tuple(fn(params) for fn in init_fns)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in <genexpr>(.0)
     63 def init_fn(params):
---> 64   return tuple(fn(params) for fn in init_fns)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/wrappers.py:544, in masked.<locals>.init_fn(params)
    541 if isinstance(params, _state_utils._ParamsPlaceholder):  # pylint:disable=protected-access
    542   return MaskedState(inner_state=inner.init(params))
--> 544 mask_tree = mask(params) if callable(mask) else mask
    545 masked_params = mask_pytree(params, mask_tree)
    546 return MaskedState(inner_state=inner.init(masked_params))

TypeError: Transformer.__call__() missing 1 required positional argument: 'y'

You can find the related discussion at patrick-kidger/equinox#771.

@AakashKumarNain changed the title from "Making certain parameters for weight decay in adamw" to "Masking certain parameters for weight decay in adamw" on Jul 13, 2024
@vroulet
Collaborator

vroulet commented Jul 17, 2024

Hello @AakashKumarNain

This was raised in #913 and @JadM133 found a solution. PR #1015 should fix this.

@AakashKumarNain
Author

AakashKumarNain commented Jul 18, 2024

Thanks @vroulet for pointing out the PR. I hope it gets merged soon because this has been a huge blocker for Equinox users. Also, do you have any immediate suggestion to make it work for now?

@JadM133

JadM133 commented Jul 18, 2024

Hello @AakashKumarNain, in the meantime you can modify two lines in _src/wrappers.py:

Line 544: mask_tree = mask (instead of "mask(params) if callable(mask) else mask")
Line 549: mask_tree = mask (same change)

This should get your code to run as expected until the pull request is merged.
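
Alternatively (an untested sketch, not needed once the PR lands): since optax also accepts a callable as the mask, you could wrap the boolean PyTree in a lambda so that optax never tries to call the Equinox module itself:

# Wrap the mask PyTree in a lambda: optax sees a callable, calls it with the
# params, and gets back the boolean PyTree instead of invoking Transformer.__call__.
optim = optax.adamw(learning_rate=1e-4, mask=lambda _: mask)
opt_state = optim.init(params)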

@AakashKumarNain
Author

Thanks @JadM133 for the suggestion. I will try it out

@vroulet
Collaborator

vroulet commented Jul 18, 2024

The PR has been merged. You'll need to install optax locally to use it (we may not release a new version soon).

@vroulet closed this as completed Jul 18, 2024
@AakashKumarNain
Author

Thank you. I will do a local install

@AakashKumarNain
Author

@vroulet @JadM133 I did a local install, and though masking now works, it broke the optim.update(...) functionality for adamw. It works fine for other optimizers like adam. We should reopen this issue.

@vroulet reopened this Jul 23, 2024
@vroulet
Collaborator

vroulet commented Jul 23, 2024

Hello @AakashKumarNain

Could you send the exact bug and a minimal reproducing example?
I've tried out the current code (see below) and don't get errors.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrd
import jax.tree_util as jtu
import jax

import optax
import optax.tree_utils as otu

# With standard pytrees
def test_mask(mask):
  opt1 = optax.adamw(1., mask=mask, weight_decay=1.)
  opt2 = optax.adamw(1., mask=mask, weight_decay=0.)
  state1 = opt1.init(params)
  state2 = opt2.init(params)
  def fun(x):
    return otu.tree_l2_norm(x, squared=True)
  grad = jax.grad(fun)(params)
  u1, _ = opt1.update(grad, state1, params)
  u2, _ = opt2.update(grad, state2, params)
  optax.apply_updates(params, u1)
  print(f'Did mask work?: {jnp.allclose(u1[1], u2[1]) and not jnp.allclose(u1[0], u2[0])}')

params = [jnp.array([[1., 2.], [3., 4.]]), jnp.array([5., 6.])]
mask_fn = lambda p: jtu.tree_map(lambda x: x.ndim != 1, p)
mask = mask_fn(params)

test_mask(mask)
test_mask(mask_fn)

# Equinox setting
@eqx.filter_value_and_grad
def grad_loss(model, input, output):
    pred = model(input)
    mse = lambda x, y : jnp.mean(jnp.square(x-y))
    return mse(pred, output)

@eqx.filter_jit
def make_step(input, output, model, states):
    loss, grads = grad_loss(model, input, output)
    u1, opt_state = optim1.update(grads, states[0], model)
    u2, opt_state = optim2.update(grads, states[1], model)
    is_working = jnp.allclose(u1.layers[0].bias, u2.layers[0].bias) & (~jnp.allclose(u1.layers[0].weight, u2.layers[0].weight))
    jax.debug.print('Did mask work?: {}', is_working)
    model = eqx.apply_updates(model, u1)
    return loss, model, opt_state

key, subkey = jax.random.split(jrd.PRNGKey(0))
xs = jnp.ones((100,))
ys = jax.random.normal(key, (1,))

model = eqx.nn.MLP(xs.shape[-1], ys.shape[-1], 10, 1, key=subkey)

lr = 1e-2
filter_spec = jtu.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(False, False),
)

optim1 = optax.adamw(1., mask=filter_spec, weight_decay=1.)
optim2 = optax.adamw(1., mask=filter_spec, weight_decay=0.)

state1 = optim1.init(eqx.filter(model, eqx.is_inexact_array))
state2 = optim2.init(eqx.filter(model, eqx.is_inexact_array))

loss, model, opt_state = make_step(xs, ys, model, (state1, state2))

Output:

Did mask work?: True
Did mask work?: True
Did mask work?: True

@JadM133

JadM133 commented Jul 23, 2024

Hello @vroulet, @AakashKumarNain, I think the issue is not with the mask. I assume @AakashKumarNain is getting the following error:

ValueError: You are using a transformation that requires the current value of parameters, but you are not passing `params` when calling `update`."""

The problem is that the update function of adamw is different from the other optimizers: it requires params (as mentioned in the documentation). So taking code that works for adam and only changing the optimizer name to adamw won't work; the update call also has to pass the current parameters, as in the code written by @vroulet above.
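
Roughly, the difference looks like this (a sketch reusing the optim/opt_state/model names from the MWE above; grads is assumed to come from something like eqx.filter_grad):

# With adam, the parameter values are not needed at update time:
updates, opt_state = optim.update(grads, opt_state)

# With adamw, the decoupled weight decay term needs the current parameter
# values, so they must be passed as the third argument:
updates, opt_state = optim.update(
    grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
)
model = eqx.apply_updates(model, updates)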

@AakashKumarNain
Author

Yup, that's the part I missed! Thanks @vroulet @JadM133 for the help.
