From 1e4280205fc19511cd490a34dc04d8d7b1d00d02 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Sat, 16 Dec 2023 00:03:09 +0700
Subject: [PATCH] Add :type option to load model under specific precision

---
 lib/bumblebee.ex                    | 51 ++++++++++++++++++++++++++---
 lib/bumblebee/conversion/pytorch.ex | 10 ++++++
 lib/bumblebee/utils/nx.ex           | 40 ++++++++++++++++++++++
 test/bumblebee_test.exs             | 18 ++++++++++
 4 files changed, 114 insertions(+), 5 deletions(-)

diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index e63d504a..4c150917 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -317,6 +317,11 @@ defmodule Bumblebee do
   @doc """
   Builds an `Axon` model according to the given specification.
 
+  ## Options
+
+    * `:type` - either a type or an `Axon.MixedPrecision` policy to apply
+      to the model
+
   ## Example
 
       spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)
@@ -324,9 +329,24 @@ defmodule Bumblebee do
 
   """
   @doc type: :model
-  @spec build_model(Bumblebee.ModelSpec.t()) :: Axon.t()
-  def build_model(%module{} = spec) do
-    module.model(spec)
+  @spec build_model(Bumblebee.ModelSpec.t(), keyword()) :: Axon.t()
+  def build_model(%module{} = spec, opts \\ []) do
+    opts = Keyword.validate!(opts, [:type])
+
+    model = module.model(spec)
+
+    case opts[:type] do
+      nil ->
+        model
+
+      %Axon.MixedPrecision.Policy{} = policy ->
+        Axon.MixedPrecision.apply_policy(model, policy)
+
+      type ->
+        type = Nx.Type.normalize!(type)
+        policy = Axon.MixedPrecision.create_policy(params: type, compute: type, output: type)
+        Axon.MixedPrecision.apply_policy(model, policy)
+    end
   end
 
   @doc """
@@ -446,6 +466,22 @@ defmodule Bumblebee do
   The model is downloaded and cached on your disk, use `cache_dir/0`
   to find the location.
 
+  ## Parameters precision
+
+  On GPUs, computations that use a numeric type of lower precision can
+  be faster and use less memory, while still providing valid results.
+  You can configure the model to use a particular type by passing the
+  `:type` option, such as `:bf16`.
+
+  Some repositories have multiple variants of the parameter files
+  with different numeric types. The variant is usually indicated in
+  the file extension, and you can load a particular file by specifying
+  `:params_variant` or `:params_filename`. Note, however, that this
+  does not determine the numeric type used for inference; the file
+  type only affects download bandwidth and disk space. If you want
+  to use a lower precision for inference, make sure to also specify
+  `:type`.
+
   ## Options
 
     * `:spec` - the model specification to use when building the model.
@@ -470,6 +506,10 @@ defmodule Bumblebee do
     * `:backend` - the backend to allocate the tensors on. It is either
       an atom or a tuple in the shape `{backend, options}`
 
+    * `:type` - either a type or an `Axon.MixedPrecision` policy to apply
+      to the model. Passing this option automatically casts the parameters
+      to the desired type
+
   ## Examples
 
   By default the model type is inferred from configuration, so loading
@@ -502,13 +542,14 @@ defmodule Bumblebee do
         :architecture,
         :params_variant,
         :params_filename,
+        :log_params_diff,
         :backend,
-        :log_params_diff
+        :type
       ])
 
     with {:ok, repo_files} <- get_repo_files(repository),
          {:ok, spec} <- maybe_load_model_spec(opts, repository, repo_files),
-         model <- build_model(spec),
+         model <- build_model(spec, Keyword.take(opts, [:type])),
          {:ok, params} <- load_params(spec, model, repository, repo_files, opts) do
       {:ok, %{model: model, params: params, spec: spec}}
     end

diff --git a/lib/bumblebee/conversion/pytorch.ex b/lib/bumblebee/conversion/pytorch.ex
index 9ac46a10..60b49213 100644
--- a/lib/bumblebee/conversion/pytorch.ex
+++ b/lib/bumblebee/conversion/pytorch.ex
@@ -145,6 +145,7 @@ defmodule Bumblebee.Conversion.PyTorch do
 
           case verify_param_shape(param_expr, value) do
             :ok ->
+              value = ensure_type(param_expr, value)
              {value, diff}
 
            {:error, expected, actual} ->
@@ -486,6 +487,15 @@ defmodule Bumblebee.Conversion.PyTorch do
     Utils.Nx.map(expr, &Nx.shape/1)
   end
 
+  defp ensure_type(param_expr, value) do
+    Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
+      case {Nx.type(expr), Nx.type(tensor)} do
+        {type, type} -> tensor
+        {expected, _actual} -> Nx.as_type(tensor, expected)
+      end
+    end)
+  end
+
   defp unflatten_leading(tensor, axis_size) do
     shape =
       tensor

diff --git a/lib/bumblebee/utils/nx.ex b/lib/bumblebee/utils/nx.ex
index 113f614b..49502eef 100644
--- a/lib/bumblebee/utils/nx.ex
+++ b/lib/bumblebee/utils/nx.ex
@@ -23,6 +23,46 @@ defmodule Bumblebee.Utils.Nx do
     |> elem(0)
   end
 
+  @doc """
+  Recursively zips the given containers with the given function.
+  """
+  @spec zip_with(
+          tensor_or_container,
+          tensor_or_container,
+          (Nx.Tensor.t(), Nx.Tensor.t() -> term())
+        ) :: tensor_or_container
+        when tensor_or_container: Nx.Tensor.t() | Nx.Container.t()
+  def zip_with(left, right, fun)
+
+  def zip_with(%Nx.Tensor{} = left, %Nx.Tensor{} = right, fun) do
+    fun.(left, right)
+  end
+
+  def zip_with(left, right, fun) do
+    right_items =
+      right
+      |> Nx.Container.reduce([], fn item, acc -> [item | acc] end)
+      |> Enum.reverse()
+
+    case Nx.Container.traverse(left, right_items, &recur_zip_with(&1, &2, fun)) do
+      {result, []} ->
+        result
+
+      {_result, _leftover} ->
+        raise ArgumentError, "unable to merge arguments with incompatible structure"
+    end
+  end
+
+  defp recur_zip_with(left, [right | right_items], fun) do
+    case {left, right} do
+      {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
+        {fun.(left, right), right_items}
+
+      {left, right} ->
+        {recur_zip_with(left, right, fun), right_items}
+    end
+  end
+
   @doc """
   Returns the underlying tensor as a list.
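
For context, a small sketch (not part of the patch) of how the new `Bumblebee.Utils.Nx.zip_with/3` pairs corresponding tensors across two containers, mirroring its use in `ensure_type/2`; the map keys are made up for illustration:

    # Cast every tensor in `values` to the type of the matching template tensor,
    # the same way ensure_type/2 aligns loaded params with the model expression.
    template = %{"kernel" => Nx.template({2, 2}, :bf16), "bias" => Nx.template({2}, :bf16)}
    values = %{"kernel" => Nx.iota({2, 2}, type: :f32), "bias" => Nx.tensor([0.0, 0.0])}

    Bumblebee.Utils.Nx.zip_with(template, values, fn expr, tensor ->
      Nx.as_type(tensor, Nx.type(expr))
    end)
    #=> a map with both tensors cast to {:bf, 16}
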
diff --git a/test/bumblebee_test.exs b/test/bumblebee_test.exs
index e86dd7d2..a9d7e1b4 100644
--- a/test/bumblebee_test.exs
+++ b/test/bumblebee_test.exs
@@ -63,5 +63,23 @@ defmodule BumblebeeTest do
         )
       end
     end
+
+    test "passing :type casts params accordingly" do
+      assert {:ok, %{params: params}} =
+               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
+                 type: :bf16
+               )
+
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}
+
+      assert {:ok, %{params: params}} =
+               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
+                 type: Axon.MixedPrecision.create_policy(params: :f16)
+               )
+
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16}
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16}
+    end
   end
 end
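
For reference, a minimal usage sketch of the new option (illustrative only, not part of the patch; the repository id is the test fixture used above):

    # Load the parameters directly as bf16; the conversion step casts them while loading
    {:ok, %{model: model, params: params, spec: spec}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}, type: :bf16)

    # Equivalent to the atom shorthand: an explicit Axon.MixedPrecision policy
    policy = Axon.MixedPrecision.create_policy(params: :bf16, compute: :bf16, output: :bf16)
    model = Bumblebee.build_model(spec, type: policy)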