Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Nov 17, 2023
1 parent 2c73bda commit 519af74
Showing 1 changed file with 31 additions and 34 deletions.
65 changes: 31 additions & 34 deletions examples/linear_regression.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,14 @@ defmodule LinearRegression do
import Nx.Defn

@epochs 100
@step 0.001 # Sometimes also called learning rate
@gradient_step_size 0.001 # Sometimes also called "learning rate"

defn initial_random_params do
{m, new_key} =
Nx.Random.key(42)
|> Nx.Random.normal(0.0, 0.1, shape: {1, 1})

{b, _new_key} =
new_key
|> Nx.Random.normal(0.0, 0.1, shape: {1})

{m, b}
end

def train(params, linear_fn) do
def fit(linear_fn) do
data =
Stream.repeatedly(fn -> for _ <- 1..32, do: :rand.uniform() * 10 end)
|> Stream.map(fn x -> Enum.zip(x, Enum.map(x, linear_fn)) end)

for _ <- 1..@epochs, reduce: params do
for _ <- 1..@epochs, reduce: initial_random_params() do
acc ->
data
|> Enum.take(200)
Expand All @@ -40,52 +28,61 @@ defmodule LinearRegression do
end
end

defn predict({m, b}, input) do
defnp initial_random_params do
{m, new_key} =
Nx.Random.key(42)
|> Nx.Random.normal(0.0, 0.1, shape: {1, 1})

{b, _new_key} =
new_key
|> Nx.Random.normal(0.0, 0.1, shape: {1})

{m, b}
end

defnp evaluate({m, b}, input) do
Nx.dot(input, m) + b
end

defn mse_loss(params, input, target) do
target - predict(params, input)
defnp mean_squared_error(params, input, target) do
target - evaluate(params, input)
|> Nx.pow(2)
|> Nx.mean()
end

defn update({m, b} = params, input, target) do
defnp update({m, b} = params, input, target) do
{grad_m, grad_b} =
params
|> grad(&mse_loss(&1, input, target))
|> grad(&mean_squared_error(&1, input, target))

{
m - grad_m * @step,
b - grad_b * @step
m - grad_m * @gradient_step_size,
b - grad_b * @gradient_step_size
}
end
end

Nx.default_backend(Candlex.Backend)

initial_params = LinearRegression.initial_random_params()
m = :rand.normal(0.0, 10.0)
b = :rand.normal(0.0, 5.0)
IO.puts("Target m: #{m} Target b: #{b}\n")

linear_fn = fn x -> m * x + b end

# These will be very close to the above coefficients
{time, {trained_m, trained_b}} = :timer.tc(LinearRegression, :train, [initial_params, linear_fn])
# These should be very close to the above coefficients
{time, {fitted_m, fitted_b}} = :timer.tc(LinearRegression, :fit, [fn x -> m * x + b end])

trained_m =
trained_m
fitted_m =
fitted_m
|> Nx.squeeze()
|> Nx.backend_transfer()
|> Nx.to_number()

trained_b =
trained_b
fitted_b =
fitted_b
|> Nx.squeeze()
|> Nx.backend_transfer()
|> Nx.to_number()

IO.puts("Trained in #{time / 1_000_000} sec.")
IO.puts("Trained m: #{trained_m} Trained b: #{trained_b}\n")
IO.puts("Accuracy m: #{m - trained_m} Accuracy b: #{b - trained_b}")
IO.puts("Fitted in #{time / 1_000_000} sec.")
IO.puts("Fitted m: #{fitted_m} Fitted b: #{fitted_b}\n")
IO.puts("Accuracy m: #{m - fitted_m} Accuracy b: #{b - fitted_b}")

0 comments on commit 519af74

Please sign in to comment.