Add :type option to load model under specific precision #311

Merged
merged 1 commit on Dec 18, 2023
51 changes: 46 additions & 5 deletions lib/bumblebee.ex
@@ -317,16 +317,36 @@ defmodule Bumblebee do
@doc """
Builds an `Axon` model according to the given specification.

## Options

* `:type` - either a type or an `Axon.MixedPrecision` policy to apply
to the model

## Example

spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)
model = Bumblebee.build_model(spec)

"""
@doc type: :model
@spec build_model(Bumblebee.ModelSpec.t()) :: Axon.t()
def build_model(%module{} = spec) do
module.model(spec)
@spec build_model(Bumblebee.ModelSpec.t(), keyword()) :: Axon.t()
def build_model(%module{} = spec, opts \\ []) do
opts = Keyword.validate!(opts, [:type])

model = module.model(spec)

case opts[:type] do
nil ->
model

%Axon.MixedPrecision.Policy{} = policy ->
Axon.MixedPrecision.apply_policy(model, policy)

type ->
type = Nx.Type.normalize!(type)
policy = Axon.MixedPrecision.create_policy(params: type, compute: type, output: type)
Axon.MixedPrecision.apply_policy(model, policy)
end
end
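
For review context, a minimal usage sketch of the new option (not part of the diff), reusing the ResNet spec from the docstring example above:

    spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)

    # Single type: params, compute and output are all cast to bf16
    model = Bumblebee.build_model(spec, type: :bf16)

    # Explicit policy: keep params in f32, compute and output in bf16
    policy = Axon.MixedPrecision.create_policy(params: :f32, compute: :bf16, output: :bf16)
    model = Bumblebee.build_model(spec, type: policy)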

@doc """
@@ -446,6 +466,22 @@ defmodule Bumblebee do
The model is downloaded and cached on your disk, use `cache_dir/0` to
find the location.

## Parameters precision

On GPUs, computations that use a lower-precision numeric type can
be faster and use less memory, while still producing valid results.
You can configure the model to use a particular type by passing the
`:type` option, such as `:bf16`.

Some repositories have multiple variants of the parameter files
with different numeric types. The variant is usually indicated in
the file extension, and you can load a particular file by specifying
`:params_variant` or `:params_filename`. Note, however, that this
does not determine the numeric type used for inference; the file
type only matters for download bandwidth and disk space. If you
want to use a lower precision for inference, make sure to also
specify `:type`.

## Options

* `:spec` - the model specification to use when building the model.
@@ -470,6 +506,10 @@
* `:backend` - the backend to allocate the tensors on. It is either
an atom or a tuple in the shape `{backend, options}`

* `:type` - either a type or an `Axon.MixedPrecision` policy to apply
to the model. Passing this option automatically casts parameters
to the desired type

## Examples

By default the model type is inferred from configuration, so loading
@@ -502,13 +542,14 @@
:architecture,
:params_variant,
:params_filename,
:log_params_diff,
:backend,
:log_params_diff
:type
])

with {:ok, repo_files} <- get_repo_files(repository),
{:ok, spec} <- maybe_load_model_spec(opts, repository, repo_files),
model <- build_model(spec),
model <- build_model(spec, Keyword.take(opts, [:type])),
{:ok, params} <- load_params(spec, model, repository, repo_files, opts) do
{:ok, %{model: model, params: params, spec: spec}}
end
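As a usage sketch tying the docs above together (not part of the diff; the repository is the tiny test model used in the test file below):

    # Build the model and cast its parameters to bf16 while loading
    {:ok, %{params: params}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}, type: :bf16)

    Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"])
    #=> {:bf, 16}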
10 changes: 10 additions & 0 deletions lib/bumblebee/conversion/pytorch.ex
@@ -145,6 +145,7 @@ defmodule Bumblebee.Conversion.PyTorch do

case verify_param_shape(param_expr, value) do
:ok ->
value = ensure_type(param_expr, value)
{value, diff}

{:error, expected, actual} ->
@@ -486,6 +487,15 @@ defmodule Bumblebee.Conversion.PyTorch do
Utils.Nx.map(expr, &Nx.shape/1)
end

defp ensure_type(param_expr, value) do
Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
case {Nx.type(expr), Nx.type(tensor)} do
{type, type} -> tensor
{expected, _actual} -> Nx.as_type(tensor, expected)
end
end)
end

defp unflatten_leading(tensor, axis_size) do
shape =
tensor
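A minimal sketch of what `ensure_type/2` does (not part of the diff): each loaded tensor is cast to the type of the corresponding parameter template via the `zip_with/3` helper added below. The map key here is made up for illustration:

    # Template says the parameter should be bf16, but the loaded value is f32
    template = %{"kernel" => Nx.template({2, 2}, :bf16)}
    loaded = %{"kernel" => Nx.iota({2, 2}, type: :f32)}

    cast =
      Bumblebee.Utils.Nx.zip_with(template, loaded, fn expr, tensor ->
        if Nx.type(expr) == Nx.type(tensor), do: tensor, else: Nx.as_type(tensor, Nx.type(expr))
      end)

    Nx.type(cast["kernel"])
    #=> {:bf, 16}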
40 changes: 40 additions & 0 deletions lib/bumblebee/utils/nx.ex
@@ -23,6 +23,46 @@ defmodule Bumblebee.Utils.Nx do
|> elem(0)
end

@doc """
Recursively zips the given containers with the given function.
"""
@spec zip_with(
tensor_or_container,
tensor_or_container,
(Nx.Tensor.t(), Nx.Tensor.t() -> term())
) :: tensor_or_container
when tensor_or_container: Nx.Tensor.t() | Nx.Container.t()
def zip_with(left, right, fun)

def zip_with(%Nx.Tensor{} = left, %Nx.Tensor{} = right, fun) do
fun.(left, right)
end

def zip_with(left, right, fun) do
right_items =
right
|> Nx.Container.reduce([], fn item, acc -> [item | acc] end)
|> Enum.reverse()

case Nx.Container.traverse(left, right_items, &recur_zip_with(&1, &2, fun)) do
{result, []} ->
result

{_result, _leftover} ->
raise ArgumentError, "unable to merge arguments with incompatible structure"
end
end

defp recur_zip_with(left, [right | right_items], fun) do
case {left, right} do
{%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
{fun.(left, right), right_items}

{left, right} ->
{recur_zip_with(left, right, fun), right_items}
end
end

@doc """
Returns the underlying tensor as a list.

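For illustration, a small sketch of the new internal `zip_with/3` helper on nested containers (not part of the diff); tuples and maps both implement `Nx.Container`, and the function is applied leaf by leaf:

    left = {Nx.tensor([1, 2]), %{scale: Nx.tensor(1.0)}}
    right = {Nx.tensor([10, 20]), %{scale: Nx.tensor(0.5)}}

    Bumblebee.Utils.Nx.zip_with(left, right, &Nx.add/2)
    #=> {Nx.tensor([11, 22]), %{scale: Nx.tensor(1.5)}}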
18 changes: 18 additions & 0 deletions test/bumblebee_test.exs
@@ -63,5 +63,23 @@ defmodule BumblebeeTest do
)
end
end

test "passing :type casts params accordingly" do
assert {:ok, %{params: params}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
type: :bf16
)

assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}

assert {:ok, %{params: params}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
type: Axon.MixedPrecision.create_policy(params: :f16)
)

assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16}
assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16}
end
end
end