Skip to content

Commit

Permalink
refactor(examples): attempt to improve clarity of where some hard-cod…
Browse files Browse the repository at this point in the history
…ed numbers coming from or mean (#512)
  • Loading branch information
grzuy authored Nov 14, 2023
1 parent 8eb6f9b commit cd75977
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 39 deletions.
6 changes: 4 additions & 2 deletions examples/basics/multi_input_example.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Mix.install([
defmodule XOR do
require Axon

@batch_size 32

defp build_model(input_shape1, input_shape2) do
inp1 = Axon.input("x1", shape: input_shape1)
inp2 = Axon.input("x2", shape: input_shape2)
Expand All @@ -18,8 +20,8 @@ defmodule XOR do
end

defp batch do
x1 = Nx.tensor(for _ <- 1..32, do: [Enum.random(0..1)])
x2 = Nx.tensor(for _ <- 1..32, do: [Enum.random(0..1)])
x1 = Nx.tensor(for _ <- 1..@batch_size, do: [Enum.random(0..1)])
x2 = Nx.tensor(for _ <- 1..@batch_size, do: [Enum.random(0..1)])
y = Nx.logical_xor(x1, x2)
{%{"x1" => x1, "x2" => x2}, y}
end
Expand Down
5 changes: 3 additions & 2 deletions examples/basics/multi_output_example.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Mix.install([
defmodule Power do
require Axon

@batch_size 32

defp build_model do
fc =
Axon.input("input", shape: {nil, 1})
Expand Down Expand Up @@ -34,8 +36,7 @@ defmodule Power do
Stream.unfold(
Nx.Random.key(:erlang.system_time()),
fn key ->
# Batch size of 32
{x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {32, 1}, type: {:f, 32})
{x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {@batch_size, 1}, type: {:f, 32})
{{x, {Nx.pow(x, 2), Nx.pow(x, 3)}}, next_key}
end
)
Expand Down
22 changes: 15 additions & 7 deletions examples/generative/fashionmnist_autoencoder.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@ Mix.install([
defmodule FashionMNIST do
require Axon

@batch_size 32
@image_channels 1
@image_side_pixels 28
@channel_value_max 255

defmodule Autoencoder do
@image_channels 1
@image_side_pixels 28

defp encoder(x, latent_dim) do
x
|> Axon.flatten()
Expand All @@ -17,8 +25,8 @@ defmodule FashionMNIST do

defp decoder(x) do
x
|> Axon.dense(784, activation: :sigmoid)
|> Axon.reshape({:batch, 1, 28, 28})
|> Axon.dense(@image_side_pixels**2, activation: :sigmoid)
|> Axon.reshape({:batch, @image_channels, @image_side_pixels, @image_side_pixels})
end

def build_model(input_shape, latent_dim) do
Expand All @@ -31,9 +39,9 @@ defmodule FashionMNIST do
defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 1, 28, 28})
|> Nx.divide(255.0)
|> Nx.to_batched(32)
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
end

defp train_model(model, train_images, epochs) do
Expand All @@ -48,15 +56,15 @@ defmodule FashionMNIST do

train_images = transform_images(images)

model = Autoencoder.build_model({nil, 1, 28, 28}, 64) |> IO.inspect()
model = Autoencoder.build_model({nil, @image_channels, @image_side_pixels, @image_side_pixels}, 64) |> IO.inspect()

model_state = train_model(model, train_images, 5)

sample_image =
train_images
|> Enum.fetch!(0)
|> Nx.slice_along_axis(0, 1)
|> Nx.reshape({1, 1, 28, 28})
|> Nx.reshape({1, @image_channels, @image_side_pixels, @image_side_pixels})

sample_image |> Nx.to_heatmap() |> IO.inspect()

Expand Down
23 changes: 14 additions & 9 deletions examples/generative/mnist_gan.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ defmodule MNISTGAN do
alias Axon.Loop.State
import Nx.Defn

@batch_size 32
@image_channels 1
@image_side_pixels 28
@channel_value_max 255

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 1, 28, 28})
|> Nx.divide(255.0)
|> Nx.to_batched(32)
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
end

defp build_generator(z_dim) do
Expand All @@ -32,9 +37,9 @@ defmodule MNISTGAN do
|> Axon.dense(1024)
|> Axon.leaky_relu(alpha: 0.9)
|> Axon.batch_norm()
|> Axon.dense(784)
|> Axon.dense(@image_side_pixels**2)
|> Axon.tanh()
|> Axon.reshape({:batch, 28, 28, 1})
|> Axon.reshape({:batch, @image_side_pixels, @image_side_pixels, @image_channels})
end

defp build_discriminator(input_shape) do
Expand Down Expand Up @@ -80,9 +85,9 @@ defmodule MNISTGAN do
g_params = state[:generator][:model_state]

# Update D
fake_labels = Nx.iota({32, 2}, axis: 1)
fake_labels = Nx.iota({@batch_size, 2}, axis: 1)
real_labels = Nx.reverse(fake_labels)
{noise, random_next_key} = Nx.Random.normal(state[:random_key], shape: {32, 100})
{noise, random_next_key} = Nx.Random.normal(state[:random_key], shape: {@batch_size, 100})

{d_loss, d_grads} =
value_and_grad(d_params, fn params ->
Expand Down Expand Up @@ -162,7 +167,7 @@ defmodule MNISTGAN do
preds = Axon.predict(model, pstate[:generator][:model_state], noise)

preds
|> Nx.reshape({batch_size, 28, 28})
|> Nx.reshape({batch_size, @image_side_pixels, @image_side_pixels})
|> Nx.to_heatmap()
|> IO.inspect()

Expand All @@ -174,7 +179,7 @@ defmodule MNISTGAN do
train_images = transform_images(images)

generator = build_generator(100)
discriminator = build_discriminator({nil, 28, 28, 1})
discriminator = build_discriminator({nil, @image_side_pixels, @image_side_pixels, @image_channels})

discriminator
|> train_loop(generator)
Expand Down
14 changes: 9 additions & 5 deletions examples/vision/cifar10.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,27 @@ Mix.install([
defmodule Cifar do
require Axon

@batch_size 32
@channel_value_max 255
@label_values Enum.to_list(0..9)

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape(shape, names: [:count, :channels, :width, :height])
# Move channels to last position to match what conv layer expects
|> Nx.transpose(axes: [:count, :width, :height, :channels])
|> Nx.divide(255.0)
|> Nx.to_batched(32)
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
|> Enum.split(1500)
end

defp transform_labels({bin, type, _}) do
bin
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched(32)
|> Nx.equal(Nx.tensor(@label_values))
|> Nx.to_batched(@batch_size)
|> Enum.split(1500)
end

Expand All @@ -39,7 +43,7 @@ defmodule Cifar do
|> Axon.flatten()
|> Axon.dense(64, activation: :relu)
|> Axon.dropout(rate: 0.5)
|> Axon.dense(10, activation: :softmax)
|> Axon.dense(length(@label_values), activation: :softmax)
end

defp train_model(model, train_images, train_labels, epochs) do
Expand Down
11 changes: 7 additions & 4 deletions examples/vision/cnn_image_denoising.exs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ defmodule MnistDenoising do
@noise_factor 0.4
@batch_size 32
@epochs 25
@image_channels 1
@image_side_pixels 28
@channel_value_max 255

def run do
{images, _} = Scidata.MNIST.download()
Expand All @@ -26,7 +29,7 @@ defmodule MnistDenoising do
noisy_train_images |> Enum.take(1) |> hd() |> display_image()

# Train with noisy images as input and train images as targets
model = build_model({nil, 1, 28, 28})
model = build_model({nil, @image_channels, @image_side_pixels, @image_side_pixels})

model_state =
model
Expand All @@ -46,8 +49,8 @@ defmodule MnistDenoising do
defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 28, 28, 1})
|> Nx.divide(255.0)
|> Nx.reshape({elem(shape, 0), @image_side_pixels, @image_side_pixels, @image_channels})
|> Nx.divide(@channel_value_max)
|> Nx.to_batched_list(@batch_size)
# Test split
|> Enum.split(1750)
Expand All @@ -63,7 +66,7 @@ defmodule MnistDenoising do
defp display_image(images) do
images
|> Nx.slice_along_axis(0, 1)
|> Nx.reshape({28, 28, 1})
|> Nx.reshape({@image_side_pixels, @image_side_pixels, @image_channels})
|> Nx.to_heatmap()
|> IO.inspect()
end
Expand Down
10 changes: 7 additions & 3 deletions examples/vision/horses_or_humans.exs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ defmodule HorsesOrHumans do
# or you can use Req to download and extract the zip file and iterate
# over the resulting data
@directories "examples/vision/{horses,humans}/*"
@batch_size 32
@image_channels 4
@image_side_pixels 300
@channel_value_max 255

def data() do
Path.wildcard(@directories)
|> Stream.chunk_every(32, 32, :discard)
|> Stream.chunk_every(@batch_size, @batch_size, :discard)
|> Task.async_stream(fn batch ->
{inp, labels} = batch |> Enum.map(&parse_png/1) |> Enum.unzip()
{Nx.stack(inp), Nx.stack(labels)}
Expand All @@ -29,7 +33,7 @@ defmodule HorsesOrHumans do

defnp augment(inp) do
# Normalize
inp = inp / 255.0
inp = inp / @channel_value_max

# For now just a random flip
if Nx.random_uniform({}) > 0.5 do
Expand Down Expand Up @@ -78,7 +82,7 @@ defmodule HorsesOrHumans do
end

def run() do
model = build_model({nil, 300, 300, 4}) |> IO.inspect()
model = build_model({nil, @image_side_pixels, @image_side_pixels, @image_channels}) |> IO.inspect()
optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4)
centralized_optimizer = Polaris.Updates.compose(Polaris.Updates.centralize(), optimizer)

Expand Down
19 changes: 12 additions & 7 deletions examples/vision/mnist.exs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ Mix.install([
defmodule Mnist do
require Axon

@batch_size 32
@image_side_pixels 28
@channel_value_max 255
@label_values Enum.to_list(0..9)

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 784})
|> Nx.divide(255.0)
|> Nx.to_batched(32)
|> Nx.reshape({elem(shape, 0), @image_side_pixels**2})
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
# Test split
|> Enum.split(1750)
end
Expand All @@ -23,8 +28,8 @@ defmodule Mnist do
bin
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched(32)
|> Nx.equal(Nx.tensor(@label_values))
|> Nx.to_batched(@batch_size)
# Test split
|> Enum.split(1750)
end
Expand All @@ -33,7 +38,7 @@ defmodule Mnist do
Axon.input("input", shape: input_shape)
|> Axon.dense(128, activation: :relu)
|> Axon.dropout()
|> Axon.dense(10, activation: :softmax)
|> Axon.dense(length(@label_values), activation: :softmax)
end

defp train_model(model, train_images, train_labels, epochs) do
Expand All @@ -56,7 +61,7 @@ defmodule Mnist do
{train_images, test_images} = transform_images(images)
{train_labels, test_labels} = transform_labels(labels)

model = build_model({nil, 784}) |> IO.inspect()
model = build_model({nil, @image_side_pixels**2}) |> IO.inspect()

IO.write("\n\n Training Model \n\n")

Expand Down

0 comments on commit cd75977

Please sign in to comment.