Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] improve toy examples #3813

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions flax/experimental/nnx/examples/toy_examples/01_functional_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def train_step(params, counts, batch):
x, y = batch

def loss_fn(params):
y_pred, (_, updates) = static.apply(params, counts)(x)
counts_ = updates.extract(Count)
model = static.merge(params, counts)
y_pred = model(x)
new_counts = model.extract(Count)
loss = jnp.mean((y - y_pred) ** 2)
return loss, counts_
return loss, new_counts

grad, counts = jax.grad(loss_fn, has_aux=True)(params)
# |-------- sgd ---------|
Expand All @@ -82,7 +83,8 @@ def loss_fn(params):
@jax.jit
def test_step(params: nnx.State, counts: nnx.State, batch):
x, y = batch
y_pred, _ = static.apply(params, counts)(x)
model = static.merge(params, counts)
y_pred = model(x)
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,43 +63,41 @@ def __call__(self, x):
optimizer = nnx.Optimizer(model, tx)

@nnx.jit
def train_step(optimizer: nnx.Optimizer, batch):
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
x, y = batch

def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y - y_pred) ** 2)

# |--default--|
grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(optimizer.model)
# |--default--|
grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
# sgd update
optimizer.update(grads=grads)

# no return!!!


@nnx.jit
def test_step(optimizer: nnx.Optimizer, batch):
def test_step(model: MLP, batch):
x, y = batch
y_pred = optimizer.model(x)
y_pred = model(x)
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}


total_steps = 10_000
for step, batch in enumerate(dataset(32)):
train_step(optimizer, batch)
train_step(model, optimizer, batch)

if step % 1000 == 0:
logs = test_step(optimizer, (X, Y))
logs = test_step(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break

print('times called:', optimizer.model.count.value)
print('times called:', model.count.value)

y_pred = optimizer.model(X)
y_pred = model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
Expand Down
Loading