diff --git a/lib/mix/tasks/train_model.ex b/lib/mix/tasks/train_model.ex index 14534e3..1fd4817 100644 --- a/lib/mix/tasks/train_model.ex +++ b/lib/mix/tasks/train_model.ex @@ -128,11 +128,11 @@ defmodule Mix.Tasks.TrainModel do predict_from_data_frame(model, validate_df, iteration_range: {0, model.best_iteration}) overall = - (validate_df - |> DF.mutate(regression: time + %Duration{value: 1_000, precision: :millisecond} * ^pred) - |> overall_accuracy(:time, :arrival_time, :regression, &accuracy/1))[ - :accuracy - ][0] + validate_df + |> DF.mutate(regression: time + %Duration{value: 1_000, precision: :millisecond} * ^pred) + |> overall_accuracy(:time, :arrival_time, :regression, &accuracy/1) + |> Access.get(:accuracy) + |> Access.get(0) IO.puts("Overall accuracy: #{overall}%") diff --git a/lib/ride_along/eta_calculator/training.ex b/lib/ride_along/eta_calculator/training.ex index d215779..ed51e2e 100644 --- a/lib/ride_along/eta_calculator/training.ex +++ b/lib/ride_along/eta_calculator/training.ex @@ -71,14 +71,14 @@ defmodule RideAlong.EtaCalculator.Training do for_result = for slice <- 0..slices do - model - |> Model.predict_from_tensor( + tensor = df |> DF.select(Model.feature_names()) |> DF.slice((slice * slice_size)..((slice + 1) * slice_size - 1)) - |> Nx.stack(axis: 1), - opts - ) + |> Nx.stack(axis: 1) + + model + |> Model.predict_from_tensor(tensor, opts) |> Series.from_tensor() end @@ -279,11 +279,11 @@ defmodule RideAlong.EtaCalculator.Training do pred = predict_from_data_frame(state.booster, validate_df) overall = - (validate_df - |> DF.mutate(regression: time + %Duration{value: 1_000, precision: :millisecond} * ^pred) - |> overall_accuracy(:time, :arrival_time, :regression, &accuracy/1))[ - :accuracy - ][0] + validate_df + |> DF.mutate(regression: time + %Duration{value: 1_000, precision: :millisecond} * ^pred) + |> overall_accuracy(:time, :arrival_time, :regression, &accuracy/1) + |> Access.get(:accuracy) + |> Access.get(0) if verbose? do IO.puts("Iteration #{state.iteration}: #{overall}%") diff --git a/mix.exs b/mix.exs index 49b1d69..a97595d 100644 --- a/mix.exs +++ b/mix.exs @@ -65,7 +65,7 @@ defmodule RideAlong.MixProject do {:req, "~> 0.4"}, {:sobelow, "~> 0.13.0", only: :dev, runtime: false}, {:stream_data, "~> 1.0", only: :test}, - {:styler, "~> 1.1"}, + {:styler, "~> 1.2"}, {:tailwind, "~> 0.2", runtime: Mix.env() == :dev}, {:tds, "~> 2.3"}, {:telemetry_metrics, "~> 1.0"}, diff --git a/mix.lock b/mix.lock index 3de7c32..2ad2720 100644 --- a/mix.lock +++ b/mix.lock @@ -70,7 +70,7 @@ "sobelow": {:hex, :sobelow, "0.13.0", "218afe9075904793f5c64b8837cc356e493d88fddde126a463839351870b8d1e", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "cd6e9026b85fc35d7529da14f95e85a078d9dd1907a9097b3ba6ac7ebbe34a0d"}, "ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"}, "stream_data": {:hex, :stream_data, "1.1.2", "05499eaec0443349ff877aaabc6e194e82bda6799b9ce6aaa1aadac15a9fdb4d", [:mix], [], "hexpm", "129558d2c77cbc1eb2f4747acbbea79e181a5da51108457000020a906813a1a9"}, - "styler": {:hex, :styler, "1.1.2", "d5b14cd4f8f7cc45624d9485cd0edb277ec92583b118409cfcbcb7c78efa5f4b", [:mix], [], "hexpm", "b46edab1f129d0c839d426755e172cf92118e5fac877456d074156b335f1f80b"}, + "styler": {:hex, :styler, "1.2.1", "28f9e3d4b065c22575c56b8ae03d05188add1b21bec5ae664fc1551e2dfcc41b", [:mix], [], "hexpm", "71dc33980e530d21ca54db9c2075e646faa6e7b744a9d4a3dfb0ff01f56595f0"}, "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "table_rex": {:hex, :table_rex, "4.0.0", "3c613a68ebdc6d4d1e731bc973c233500974ec3993c99fcdabb210407b90959b", [:mix], [], "hexpm", "c35c4d5612ca49ebb0344ea10387da4d2afe278387d4019e4d8111e815df8f55"}, "tailwind": {:hex, :tailwind, "0.2.4", "5706ec47182d4e7045901302bf3a333e80f3d1af65c442ba9a9eed152fb26c2e", [:mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}], "hexpm", "c6e4a82b8727bab593700c998a4d98cf3d8025678bfde059aed71d0000c3e463"},