Skip to content

Commit

Permalink
fixes confusing source of value max for channel input normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jul 19, 2023
1 parent c4771cb commit 13f6662
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 9 deletions.
3 changes: 2 additions & 1 deletion examples/generative/fashionmnist_autoencoder.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ defmodule FashionMNIST do
@batch_size 32
@image_channels 1
@image_side_pixels 28
@channel_value_max 255

defmodule Autoencoder do
@image_channels 1
Expand Down Expand Up @@ -39,7 +40,7 @@ defmodule FashionMNIST do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
|> Nx.divide(Nx.Constants.max(type))
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
end

Expand Down
3 changes: 2 additions & 1 deletion examples/generative/mnist_gan.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ defmodule MNISTGAN do
@batch_size 32
@image_channels 1
@image_side_pixels 28
@channel_value_max 255

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), @image_channels, @image_side_pixels, @image_side_pixels})
|> Nx.divide(Nx.Constants.max(type))
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
end

Expand Down
3 changes: 2 additions & 1 deletion examples/vision/cifar10.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ defmodule Cifar do
@batch_size 32
@image_channels 3
@image_side_pixels 32
@channel_value_max 255
@label_values Enum.to_list(0..9)

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), @image_side_pixels, @image_side_pixels, @image_channels})
|> Nx.divide(Nx.Constants.max(type))
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
|> Enum.split(1500)
end
Expand Down
3 changes: 2 additions & 1 deletion examples/vision/cnn_image_denoising.exs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ defmodule MnistDenoising do
@epochs 25
@image_channels 1
@image_side_pixels 28
@channel_value_max 255

def run do
{images, _} = Scidata.MNIST.download()
Expand Down Expand Up @@ -49,7 +50,7 @@ defmodule MnistDenoising do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), @image_side_pixels, @image_side_pixels, @image_channels})
|> Nx.divide(Nx.Constants.max(type))
|> Nx.divide(@channel_value_max)
|> Nx.to_batched_list(@batch_size)
# Test split
|> Enum.split(1750)
Expand Down
8 changes: 4 additions & 4 deletions examples/vision/horses_or_humans.exs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ defmodule HorsesOrHumans do
@batch_size 32
@image_channels 4
@image_side_pixels 300
@input_type {:u, 8}
@channel_value_max 255

def data() do
Path.wildcard(@directories)
Expand All @@ -33,7 +33,7 @@ defmodule HorsesOrHumans do

defnp augment(inp) do
# Normalize
inp = inp / Nx.Constants.max(@input_type)
inp = inp / @channel_value_max

# For now just a random flip
if Nx.random_uniform({}) > 0.5 do
Expand All @@ -46,8 +46,8 @@ defmodule HorsesOrHumans do
defp parse_png(filename) do
class =
if String.contains?(filename, "horses"),
do: Nx.tensor([1, 0], type: @input_type),
else: Nx.tensor([0, 1], type: @input_type)
do: Nx.tensor([1, 0], type: {:u, 8}),
else: Nx.tensor([0, 1], type: {:u, 8})

{:ok, img} = StbImage.read_file(filename)

Expand Down
3 changes: 2 additions & 1 deletion examples/vision/mnist.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ defmodule Mnist do

@batch_size 32
@image_side_pixels 28
@channel_value_max 255
@label_values Enum.to_list(0..9)

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), @image_side_pixels**2})
|> Nx.divide(Nx.Constants.max(type))
|> Nx.divide(@channel_value_max)
|> Nx.to_batched(@batch_size)
# Test split
|> Enum.split(1750)
Expand Down

0 comments on commit 13f6662

Please sign in to comment.