Add test that exports to MLIR a small sharded Llama model (#220)

The decode step requires exporting in non-strict torch mode due to
pytorch/pytorch#135061
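
As a side note, here is a minimal sketch of what non-strict export looks like, using a toy module rather than the actual sharded decode step (the module, shapes, and names are illustrative only):

```python
import torch


class ToyDecodeStep(torch.nn.Module):
    """Stand-in for the sharded Llama decode step; not the real model."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


# strict=False selects torch.export's non-strict (Python-interpreted) tracing
# path, which avoids the TorchDynamo limitation referenced above.
exported = torch.export.export(ToyDecodeStep(), (torch.randn(2, 3),), strict=False)
print(exported)
```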

This export required extending the registration functionality of our custom
tensor types by providing `flatten_with_keys_fn`. It also required bumping
the PyTorch version to >=2.4 for the other export tests.
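
For illustration, a small self-contained sketch of the `flatten_with_keys_fn` pattern on a toy container, assuming the PyTorch >=2.4 pytree API (`SequenceKey`, `tree_flatten_with_path`); the sharktank tensor types in this commit follow the same shape:

```python
from dataclasses import dataclass

import torch
from torch.utils._pytree import (
    SequenceKey,
    register_pytree_node,
    tree_flatten_with_path,
)


@dataclass
class Pair:
    a: torch.Tensor
    b: torch.Tensor


def _flatten(p: Pair):
    return [p.a, p.b], None  # (leaves, context)


def _unflatten(values, context):
    return Pair(*values)


def _flatten_with_keys(p: Pair):
    # Same leaves as _flatten, but each paired with a key so that keyed
    # flattening (used by non-strict export) can name every leaf.
    values, context = _flatten(p)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


register_pytree_node(
    Pair,
    flatten_fn=_flatten,
    unflatten_fn=_unflatten,
    flatten_with_keys_fn=_flatten_with_keys,
)

keyed_leaves, spec = tree_flatten_with_path(Pair(torch.zeros(2), torch.ones(2)))
# keyed_leaves is a list of (key_path, tensor) pairs, one per flattened leaf.
```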

The export to MLIR currently fails with:

TypeError: Unsupported torch type conversion for
!torch.vtensor<[3,1,7],complex<f32>>

This needs further debugging. The detailed error is
[here](https://github.com/user-attachments/files/17163454/sharded-llama-export-to-mlir-failure.txt).
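
As a rough guess at where the complex type comes from (not verified against the failing test): a complex-valued intermediate, such as a rotary-embedding table, survives into the exported program and lowers to a `!torch.vtensor<..., complex<f32>>` value that the MLIR type conversion does not yet handle. A hypothetical minimal module with a complex output of the same shape:

```python
import torch


class ComplexOutput(torch.nn.Module):
    """Hypothetical reproducer: returns a complex<f32> tensor of shape [3, 1, 7]."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pairs the trailing dimension of size 2 into complex numbers.
        return torch.view_as_complex(x.reshape(3, 1, 7, 2))


exported = torch.export.export(ComplexOutput(), (torch.randn(3, 1, 7, 2),), strict=False)
# Importing this program into MLIR would exercise the same complex<f32>
# type conversion that the sharded Llama export currently trips over.
```
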
sogartar authored Sep 27, 2024
1 parent 86e0711 commit fca29f4
Showing 3 changed files with 207 additions and 96 deletions.
9 changes: 8 additions & 1 deletion sharktank/sharktank/ops/sharded_impls.py
@@ -10,7 +10,6 @@
 import itertools
 from numbers import Number
 import math
-import numpy as np

 from ..types import (
     AnyTensor,
@@ -365,6 +364,14 @@ def elementwise_binary_replicated_lhs_unsharded_rhs(
     return elementwise(operator, x, y_replicated, *args, **kwargs)


+@elementwise.override(Tensor, ReplicatedTensor)
+def elementwise_binary_replicated_lhs_unsharded_rhs(
+    operator, x: Tensor, y: ReplicatedTensor, *args, **kwargs
+):
+    x_replicated = reshard_like(x, like=y)
+    return elementwise(operator, x_replicated, y, *args, **kwargs)
+
+
 # Embedding Lookup
 @embedding_lookup.override(ReplicatedTensor, ReplicatedTensor)
 def embedding_lookup_default(
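
A hedged usage sketch of the new override: when the left operand is a plain `torch.Tensor` and the right is a `ReplicatedTensor`, the unsharded side is first replicated to match and the elementwise op is then re-dispatched shard by shard. The import paths and the `ReplicatedTensor(ts=...)` constructor below are assumed from this diff, not verified:

```python
import torch

from sharktank import ops
from sharktank.types import ReplicatedTensor

x = torch.rand(3, 4)
shard = torch.rand(3, 4)
y = ReplicatedTensor(ts=[shard, shard])  # two identical shards, as on two devices

# Dispatches to the (Tensor, ReplicatedTensor) override added above: x is
# resharded (replicated) like y, then elementwise runs per shard.
z = ops.elementwise(torch.add, x, y)
assert isinstance(z, ReplicatedTensor)
```
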
20 changes: 19 additions & 1 deletion sharktank/sharktank/types/tensors.py
@@ -25,7 +25,7 @@

 import torch
 from torch import Tensor
-from torch.utils._pytree import register_pytree_node
+from torch.utils._pytree import register_pytree_node, SequenceKey
 from ..utils.math import ceildiv
 from shark_turbine.aot import (
     ExternalTensorTrait,
@@ -1286,10 +1286,16 @@ def unflatten_defult_primitive_tensor(
     return DefaultPrimitiveTensor(data=values_as_list[0], name=ctx["name"])


+def flatten_with_keys_default_primitive_tensor(t: DefaultPrimitiveTensor):
+    values, context = flatten_default_primitive_tensor(t)
+    return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
 register_pytree_node(
     DefaultPrimitiveTensor,
     flatten_fn=flatten_default_primitive_tensor,
     unflatten_fn=unflatten_defult_primitive_tensor,
+    flatten_with_keys_fn=flatten_with_keys_default_primitive_tensor,
 )


@@ -1307,10 +1313,16 @@ def unflatten_split_primitive_tensor(
     )


+def flatten_with_keys_split_primitive_tensor(t: SplitPrimitiveTensor):
+    values, context = flatten_split_primitive_tensor(t)
+    return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
 register_pytree_node(
     SplitPrimitiveTensor,
     flatten_fn=flatten_split_primitive_tensor,
     unflatten_fn=unflatten_split_primitive_tensor,
+    flatten_with_keys_fn=flatten_with_keys_split_primitive_tensor,
 )


@@ -1326,8 +1338,14 @@ def unflatten_replicated_tensor(
     return ReplicatedTensor(ts=list(values), name=ctx["name"])


+def flatten_with_keys_replicated_tensor(t: ReplicatedTensor):
+    values, context = flatten_replicated_tensor(t)
+    return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
 register_pytree_node(
     ReplicatedTensor,
     flatten_fn=flatten_replicated_tensor,
     unflatten_fn=unflatten_replicated_tensor,
+    flatten_with_keys_fn=flatten_with_keys_replicated_tensor,
 )
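
To show what the registrations above enable, a hedged round-trip sketch using `DefaultPrimitiveTensor` (its `data=`/`name=` constructor arguments are taken from the unflatten function in this diff) and the keyed pytree helpers from PyTorch >=2.4:

```python
import torch
from torch.utils._pytree import tree_flatten_with_path, tree_unflatten

from sharktank.types import DefaultPrimitiveTensor

t = DefaultPrimitiveTensor(data=torch.randn(2, 3), name="activation")

# With flatten_with_keys_fn registered, keyed flattening works and every leaf
# comes back with a key path instead of raising for the custom tensor type.
keyed_leaves, spec = tree_flatten_with_path(t)
for key_path, leaf in keyed_leaves:
    print(key_path, tuple(leaf.shape))

# Round-tripping through the treespec reconstructs the custom type.
rebuilt = tree_unflatten([leaf for _, leaf in keyed_leaves], spec)
assert isinstance(rebuilt, DefaultPrimitiveTensor)
```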