diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7fecfd96..9d68d197 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@
>
> please add your unreleased change here.
+- [Feature] Support more generic Torch model inference
- [Improvement] Optimize one-time setup for yacl ot
- [Improvement] Optimize sort performance
diff --git a/docs/reference/pphlo_op_doc.md b/docs/reference/pphlo_op_doc.md
index 69367ba5..74725045 100644
--- a/docs/reference/pphlo_op_doc.md
+++ b/docs/reference/pphlo_op_doc.md
@@ -18,9 +18,9 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-edge_padding_low | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-edge_padding_high | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-interior_padding | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+edge_padding_low | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+edge_padding_high | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+interior_padding | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -135,9 +135,9 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-window_dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-window_strides | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-window_dilations | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+window_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+window_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+window_dilations | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
onehot_index | ::mlir::BoolAttr | bool attribute |
@@ -213,7 +213,7 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-broadcast_dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+broadcast_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -455,7 +455,7 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-window_strides | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+window_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
dimension_numbers | ::mlir::pphlo::ConvDimensionNumbersAttr | Structure of dimension information for conv op |
feature_group_count | ::mlir::IntegerAttr | 64-bit signless integer attribute |
batch_group_count | ::mlir::IntegerAttr | 64-bit signless integer attribute |
@@ -658,7 +658,7 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-slice_sizes | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+slice_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -837,43 +837,6 @@ Effects: MemoryEffects::Effect{}
| :-----: | ----------- |
| `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values
-### `pphlo.gather` (pphlo::GatherOp)
-
-_Gather operator_
-
-Stitches together several slices of `operand` from offsets specified in
-`start_indices` (each slice at a potentially different runtime offset).
-
-See https://www.tensorflow.org/xla/operation_semantics#gather.
-
-Traits: AlwaysSpeculatableImplTrait
-
-Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
-
-Effects: MemoryEffects::Effect{}
-
-#### Attributes:
-
-
-Attribute | MLIR Type | Description |
-dimension_numbers | ::mlir::pphlo::GatherDimensionNumbersAttr | Attribute that models the dimension information for gather |
-slice_sizes | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-indices_are_sorted | ::mlir::BoolAttr | bool attribute |
-
-
-#### Operands:
-
-| Operand | Description |
-| :-----: | ----------- |
-| `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values
-| `start_indices` | statically shaped tensor of public integer type or secret integer type values
-
-#### Results:
-
-| Result | Description |
-| :----: | ----------- |
-«unnamed» | statically shaped tensor of PPHlo public type or PPHlo secret type values
-
### `pphlo.greater_equal` (pphlo::GreaterEqualOp)
_Greater_equal comparison operator_
@@ -1185,8 +1148,8 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-window_dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-window_strides | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+window_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+window_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -1491,7 +1454,7 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlock, SingleBloc
Attribute | MLIR Type | Description |
-dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -1522,9 +1485,9 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlock, SingleBloc
Attribute | MLIR Type | Description |
-window_dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-window_strides | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-window_dilations | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+window_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+window_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+window_dilations | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -1630,7 +1593,7 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -1745,8 +1708,8 @@ Traits: RecursiveMemoryEffects
Attribute | MLIR Type | Description |
-window_dimensions | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-window_strides | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+window_dimensions | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+window_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -1986,9 +1949,9 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-start_indices | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-limit_indices | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
-strides | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+start_indices | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+limit_indices | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
+strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
@@ -2136,7 +2099,7 @@ Effects: MemoryEffects::Effect{}
Attribute | MLIR Type | Description |
-permutation | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute |
+permutation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
#### Operands:
diff --git a/examples/python/ml/BUILD.bazel b/examples/python/ml/BUILD.bazel
index 065208db..264e7537 100644
--- a/examples/python/ml/BUILD.bazel
+++ b/examples/python/ml/BUILD.bazel
@@ -40,7 +40,8 @@ py_test(
"//examples/python/ml/stax_mnist_classifier",
"//examples/python/ml/stax_nn",
"//examples/python/ml/tf_experiment",
- "//examples/python/ml/torch_experiment",
+ "//examples/python/ml/torch_lr_experiment",
+ "//examples/python/ml/torch_resnet_experiment",
"//spu/utils:distributed",
],
)
diff --git a/examples/python/ml/README.md b/examples/python/ml/README.md
index feb4f10c..20124747 100644
--- a/examples/python/ml/README.md
+++ b/examples/python/ml/README.md
@@ -27,4 +27,5 @@ library, and private inference of a pre-trained ResNet-50 model based on [Micros
* [jraph_gnn](jraph_gnn/): Private training of a [graph convolutional network](https://arxiv.org/abs/1609.02907) model with
[Jraph](https://github.com/deepmind/jraph).
* [tf_experiment](tf_experiment/): Private training of a logistic regression model with TensorFlow (**experimental**).
-* [torch_experiment](torch_experiment/): Private inference of a linear regression model with PyTorch (**experimental**).
+* [torch_lr_experiment](torch_lr_experiment/): Private inference of a logistic regression model with PyTorch (**experimental**).
+* [torch_resnet_experiment](torch_resnet_experiment/): Private inference of a [ResNet](https://arxiv.org/abs/1512.03385) model with PyTorch (**experimental**).
diff --git a/examples/python/ml/haiku_lstm/README.md b/examples/python/ml/haiku_lstm/README.md
index 4a150d4b..97f0990c 100644
--- a/examples/python/ml/haiku_lstm/README.md
+++ b/examples/python/ml/haiku_lstm/README.md
@@ -9,7 +9,7 @@ This example comes from Haiku official github repo:
1. Install dependencies
```sh
- pip install -r requirements.txt
+ pip install -r ../requirements.txt
```
2. Launch SPU backend runtime
diff --git a/examples/python/ml/haiku_lstm/requirements.txt b/examples/python/ml/haiku_lstm/requirements.txt
deleted file mode 100644
index e4ca8fe6..00000000
--- a/examples/python/ml/haiku_lstm/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-dm-haiku==0.0.10
-plotnine
diff --git a/examples/python/ml/jraph_gnn/README.md b/examples/python/ml/jraph_gnn/README.md
index 0739ddd4..32883888 100644
--- a/examples/python/ml/jraph_gnn/README.md
+++ b/examples/python/ml/jraph_gnn/README.md
@@ -9,7 +9,7 @@ This example comes from Jraph official github repo:
1. Install dependencies
```sh
- pip install -r requirements.txt
+ pip install -r ../requirements.txt
```
2. Set runtime configuration
diff --git a/examples/python/ml/jraph_gnn/requirements.txt b/examples/python/ml/jraph_gnn/requirements.txt
deleted file mode 100644
index 01b28ae8..00000000
--- a/examples/python/ml/jraph_gnn/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-dm-haiku
-jraph
\ No newline at end of file
diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py
index 7526b2b9..cedd148e 100644
--- a/examples/python/ml/ml_test.py
+++ b/examples/python/ml/ml_test.py
@@ -213,20 +213,27 @@ def test_tf_experiment(self):
score = tf_experiment.run_fit_manual_grad_spu()
self.assertGreater(score, 0.9)
- def test_torch_experiment(self):
- from examples.python.ml.torch_experiment import torch_experiment
+ def test_torch_lr_experiment(self):
+ from examples.python.ml.torch_lr_experiment import torch_lr_experiment
- model = torch_experiment.LinearRegression()
- torch_experiment.train(model)
- score = torch_experiment.run_inference_on_spu(model)
+ model = torch_lr_experiment.LinearRegression()
+ torch_lr_experiment.train(model)
+ score = torch_lr_experiment.run_inference_on_spu(model)
self.assertGreater(score, 0.9)
+ def test_torch_resnet_experiment(self):
+ from examples.python.ml.torch_resnet_experiment import torch_resnet_experiment
+
+ model = torch_resnet_experiment.resnet
+ image = torch_resnet_experiment.input_batch
+ label = torch_resnet_experiment.run_inference_on_spu(model, image)
+ self.assertEqual(label, 258)
+
def test_save_and_load_model(self):
from examples.python.ml.jax_lr import jax_lr
score = jax_lr.save_and_load_model()
self.assertGreater(score, 0.9)
- pass
def suite():
@@ -246,7 +253,8 @@ def suite():
suite.addTest(UnitTests('test_save_and_load_model'))
# should put JAX tests above
suite.addTest(UnitTests('test_tf_experiment'))
- suite.addTest(UnitTests('test_torch_experiment'))
+ suite.addTest(UnitTests('test_torch_lr_experiment'))
+ # suite.addTest(UnitTests('test_torch_resnet_experiment'))
return suite
diff --git a/examples/python/ml/requirements.txt b/examples/python/ml/requirements.txt
new file mode 100644
index 00000000..133e8a38
--- /dev/null
+++ b/examples/python/ml/requirements.txt
@@ -0,0 +1,7 @@
+dm-haiku==0.0.10
+plotnine
+jraph
+optax==0.1.7
+torch==2.1.0
+torch_xla==2.1.0
+torchvision
\ No newline at end of file
diff --git a/examples/python/ml/torch_experiment/README.md b/examples/python/ml/torch_experiment/README.md
deleted file mode 100644
index dbc728c2..00000000
--- a/examples/python/ml/torch_experiment/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# Torch Example
-
-This example demonstrates how to use SPU to make inferences on a linear regression model privately with PyTorch.
-
-The model is trained with plaintext publicly. Currently, SPU's support of PyTorch is **experimental** and we only tested on Linux.
-
-1. Install a third-party dependency [Torch-MLIR](https://github.com/llvm/torch-mlir).
-
- ```sh
- pip install https://github.com/llvm/torch-mlir/releases/download/snapshot-20220830.581/torch-1.13.0.dev20220830+cpu-cp38-cp38-linux_x86_64.whl
- pip install https://github.com/llvm/torch-mlir/releases/download/snapshot-20220830.581/torch_mlir-20220830.581-cp38-cp38-linux_x86_64.whl
- ```
-
-2. Launch SPU backend runtime
-
- ```sh
- bazel run -c opt //examples/python/utils:nodectl -- up
- ```
-
-3. Run `torch_experiment` example
-
- ```sh
- bazel run -c opt //examples/python/ml/torch_experiment
- ```
diff --git a/examples/python/ml/torch_experiment/BUILD.bazel b/examples/python/ml/torch_lr_experiment/BUILD.bazel
similarity index 87%
rename from examples/python/ml/torch_experiment/BUILD.bazel
rename to examples/python/ml/torch_lr_experiment/BUILD.bazel
index 715ed42d..36cdcce1 100644
--- a/examples/python/ml/torch_experiment/BUILD.bazel
+++ b/examples/python/ml/torch_lr_experiment/BUILD.bazel
@@ -1,4 +1,4 @@
-# Copyright 2023 Ant Group Co., Ltd.
+# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@ load("@rules_python//python:defs.bzl", "py_binary")
package(default_visibility = ["//visibility:public"])
py_binary(
- name = "torch_experiment",
- srcs = ["torch_experiment.py"],
+ name = "torch_lr_experiment",
+ srcs = ["torch_lr_experiment.py"],
data = [
"//examples/python/conf",
],
diff --git a/examples/python/ml/torch_lr_experiment/README.md b/examples/python/ml/torch_lr_experiment/README.md
new file mode 100644
index 00000000..6bcc2a82
--- /dev/null
+++ b/examples/python/ml/torch_lr_experiment/README.md
@@ -0,0 +1,23 @@
+# Torch Example
+
+This example demonstrates how to use SPU to make private inferences on PyTorch models.
+
+**Note**: Currently, SPU's support of PyTorch is **experimental**.
+
+1. Install a third-party dependency [PyTorch/XLA](https://github.com/pytorch/xla).
+
+ ```sh
+ pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+ ```
+
+2. Launch SPU backend runtime
+
+ ```sh
+ bazel run -c opt //examples/python/utils:nodectl -- up
+ ```
+
+3. Run `torch_lr_experiment` example
+
+ ```sh
+ bazel run -c opt //examples/python/ml/torch_lr_experiment
+ ```
diff --git a/examples/python/ml/torch_experiment/torch_experiment.py b/examples/python/ml/torch_lr_experiment/torch_lr_experiment.py
similarity index 66%
rename from examples/python/ml/torch_experiment/torch_experiment.py
rename to examples/python/ml/torch_lr_experiment/torch_lr_experiment.py
index 8cb502a7..d06ccd56 100644
--- a/examples/python/ml/torch_experiment/torch_experiment.py
+++ b/examples/python/ml/torch_lr_experiment/torch_lr_experiment.py
@@ -22,18 +22,12 @@
import spu.utils.distributed as ppd
-# This is an experimental example to show legacy pytorch program could be run
-# by SPU. Currently we rely on torch-mlir to convert torch code into MLIR
-# (specifically MHLO) which is then consumed by SPU. To run this example,
-# torch-mlir python package should be installed. This example here trains a
-# linear regression model in plaintext and makes private inferences with joint
-# features.
# Start nodes.
# > bazel run -c opt //examples/python/utils:nodectl -- up
#
# Run this example script.
-# > bazel run -c opt //examples/python/ml/torch_experiment:torch_experiment
+# > bazel run -c opt //examples/python/ml/torch_lr_experiment:torch_lr_experiment
class LinearRegression(torch.nn.Module):
@@ -41,17 +35,15 @@ def __init__(self):
super(LinearRegression, self).__init__()
self.linear = torch.nn.Linear(30, 1)
- def forward(self, x1, x2):
- y_pred = self.linear(torch.cat((x1, x2), 1))
+ def forward(self, x):
+ y_pred = self.linear(x)
return y_pred
def train(model, n_epochs=500, lr=0.01):
print('Train model with plaintext features\n------\n')
x, y = breast_cancer()
- x1, x2 = x[:, :15], x[:, 15:]
- x1 = torch.Tensor(x1)
- x2 = torch.Tensor(x2)
+ x = torch.Tensor(x)
y = torch.Tensor(y).view(-1, 1)
criterion = torch.nn.BCEWithLogitsLoss()
@@ -59,7 +51,7 @@ def train(model, n_epochs=500, lr=0.01):
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
for _ in range(n_epochs):
- pred_y = model(x1, x2)
+ pred_y = model(x)
loss = criterion(pred_y, y)
optimizer.zero_grad()
loss.backward()
@@ -70,7 +62,6 @@ def train(model, n_epochs=500, lr=0.01):
# prepare test datasets
def breast_cancer(
- col_slicer=slice(None, None, None),
train: bool = True,
*,
normalize: bool = True,
@@ -96,7 +87,6 @@ def breast_cancer(
else:
x_ = x_test
y_ = y_test
- x_ = x_[:, col_slicer]
return x_.astype(dtype=np.float32), y_.astype(dtype=np.float32)
@@ -105,10 +95,10 @@ def breast_cancer(
def run_inference_on_cpu(model):
print('Run on CPU\n------\n')
- x_test, y_test = breast_cancer(slice(None, None, None), False)
- x1, x2 = torch.Tensor(x_test[:, :15]), torch.Tensor(x_test[:, 15:])
+ x_test, y_test = breast_cancer(False)
+ x = torch.Tensor(x_test)
start_ts = time.time()
- y_pred = model.forward(x1, x2).cpu().detach().numpy()
+ y_pred = model(x).cpu().detach().numpy()
end_ts = time.time()
auc = metrics.roc_auc_score(y_test, y_pred)
print(f"AUC(cpu)={auc}, time={end_ts-start_ts}\n------\n")
@@ -123,35 +113,36 @@ def run_inference_on_cpu(model):
ppd.init(conf["nodes"], conf["devices"], framework=ppd.Framework.EXP_TORCH)
+from collections import OrderedDict
+from jax.tree_util import tree_map
+
def run_inference_on_spu(model):
print('Run on SPU\n------\n')
- x1, _ = ppd.device("P1")(breast_cancer)(slice(None, 15), False)
- x2, _ = ppd.device("P2")(breast_cancer)(slice(15, None), False)
+
+ # load parameters and buffers on P1
+ params_buffers = OrderedDict()
+ for k, v in model.named_parameters():
+ params_buffers[k] = v
+ for k, v in model.named_buffers():
+ params_buffers[k] = v
+ params = ppd.device("P1")(
+ lambda input: tree_map(lambda x: x.detach().numpy(), input)
+ )(params_buffers)
+
+ # load inputs on P2
+ x, _ = ppd.device("P2")(breast_cancer)(False)
+
start_ts = time.time()
- y_pred_ciphertext = ppd.device('SPU')(model)(x1, x2)
+ y_pred_ciphertext = ppd.device('SPU')(model)(params, x)
end_ts = time.time()
y_pred_plaintext = ppd.get(y_pred_ciphertext)
- _, y_test = breast_cancer(slice(None, None, None), False)
+ _, y_test = breast_cancer(False)
auc = metrics.roc_auc_score(y_test, y_pred_plaintext)
print(f"AUC(cpu)={auc}, time={end_ts-start_ts}\n------\n")
return auc
-def compile_torch_to_mhlo(model):
- print('Compile torch program to mhlo test\n------\n')
- x_test, _ = breast_cancer(slice(None, None, None), False)
- x1, x2 = torch.Tensor(x_test[:, :15]), torch.Tensor(x_test[:, 15:])
- import torch_mlir
-
- module = torch_mlir.compile(
- model,
- [x1, x2],
- output_type=torch_mlir.OutputType.MHLO,
- )
- print(f"MHLO={module}\n------\n")
-
-
if __name__ == '__main__':
# For reproducibility
torch.manual_seed(0)
@@ -159,8 +150,7 @@ def compile_torch_to_mhlo(model):
model = LinearRegression()
# Train model with plaintext features
train(model)
- # Torch-mlho conversion test
- compile_torch_to_mhlo(model)
+ model.eval()
# Native torch inference
run_inference_on_cpu(model)
# SPU inference
diff --git a/examples/python/ml/torch_resnet_experiment/BUILD.bazel b/examples/python/ml/torch_resnet_experiment/BUILD.bazel
new file mode 100644
index 00000000..91d89e45
--- /dev/null
+++ b/examples/python/ml/torch_resnet_experiment/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2024 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_binary")
+
+package(default_visibility = ["//visibility:public"])
+
+py_binary(
+ name = "torch_resnet_experiment",
+ srcs = ["torch_resnet_experiment.py"],
+ data = [
+ "//examples/python/conf",
+ ],
+ deps = [
+ "//spu/utils:distributed",
+ ],
+)
diff --git a/examples/python/ml/torch_resnet_experiment/README.md b/examples/python/ml/torch_resnet_experiment/README.md
new file mode 100644
index 00000000..ab1da24d
--- /dev/null
+++ b/examples/python/ml/torch_resnet_experiment/README.md
@@ -0,0 +1,24 @@
+# Torch Example
+
+This example demonstrates how to use SPU to make private inferences on PyTorch models.
+
+**Note**: Currently, SPU's support of PyTorch is **experimental**.
+
+1. Install a third-party dependency [PyTorch/XLA](https://github.com/pytorch/xla).
+
+ ```sh
+ pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+ pip install torchvision
+ ```
+
+2. Launch SPU backend runtime
+
+ ```sh
+ bazel run -c opt //examples/python/utils:nodectl -- up
+ ```
+
+3. Run `torch_resnet_experiment` example
+
+ ```sh
+ bazel run -c opt //examples/python/ml/torch_resnet_experiment
+ ```
diff --git a/examples/python/ml/torch_resnet_experiment/torch_resnet_experiment.py b/examples/python/ml/torch_resnet_experiment/torch_resnet_experiment.py
new file mode 100644
index 00000000..16427ebc
--- /dev/null
+++ b/examples/python/ml/torch_resnet_experiment/torch_resnet_experiment.py
@@ -0,0 +1,102 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import json
+import urllib
+from collections import OrderedDict
+
+import torch
+from jax.tree_util import tree_map
+from PIL import Image
+from torchvision import transforms
+from torchvision.models import ResNet50_Weights, resnet50
+
+import spu.utils.distributed as ppd
+
+# This is an experimental example to show legacy pytorch program could be run
+# by SPU. Currently we rely on torch-xla to convert torch code into MLIR
+# (specifically StableHLO) which is then consumed by SPU. To run this example,
+# torch-xla python package should be installed.
+
+# Start nodes.
+# > bazel run -c opt //examples/python/utils:nodectl -- up
+#
+# Run this example script.
+# > bazel run -c opt //examples/python/ml/torch_resnet_experiment:torch_resnet_experiment
+
+
+parser = argparse.ArgumentParser(description='distributed driver.')
+parser.add_argument("-c", "--config", default="examples/python/conf/3pc.json")
+args = parser.parse_args()
+
+with open(args.config, 'r') as file:
+ conf = json.load(file)
+
+ppd.init(conf["nodes"], conf["devices"], framework=ppd.Framework.EXP_TORCH)
+
+url, filename = (
+ "https://github.com/pytorch/hub/raw/master/images/dog.jpg",
+ "dog.jpg",
+)
+
+urllib.request.urlretrieve(url, filename)
+
+
+input_image = Image.open(filename)
+preprocess = transforms.Compose(
+ [
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ]
+)
+input_tensor = preprocess(input_image)
+input_batch = input_tensor.unsqueeze(0)
+resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+resnet.eval()
+
+
+def run_inference_on_cpu(model, image):
+ print('Run on CPU\n------\n')
+ output = model(image)
+ # model predicts one of the 1000 ImageNet classes
+ predicted_label = output.argmax(-1).item()
+ print(f"predicted_label={predicted_label}\n------\n")
+ return predicted_label
+
+
+def run_inference_on_spu(model, image):
+ print('Run on SPU\n------\n')
+ params_buffers = OrderedDict()
+ for k, v in model.named_parameters():
+ params_buffers[k] = v
+ for k, v in model.named_buffers():
+ params_buffers[k] = v
+ params = ppd.device("P1")(
+ lambda input: tree_map(lambda x: x.detach().numpy(), input)
+ )(params_buffers)
+ image_hat = ppd.device("P2")(lambda x: x.detach().numpy())(image)
+ res = ppd.device("SPU")(model)(params, image_hat)
+ predicted_label = ppd.get(res).argmax(-1).item()
+ print(f"predicted_label={predicted_label}\n------\n")
+ return predicted_label
+
+
+if __name__ == '__main__':
+ torch.manual_seed(0)
+ run_inference_on_cpu(resnet, input_batch)
+ run_inference_on_spu(resnet, input_batch)
diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc
index e5be8f8c..08398f8b 100644
--- a/libspu/compiler/core/core.cc
+++ b/libspu/compiler/core/core.cc
@@ -62,8 +62,6 @@ void Core::buildPipeline(mlir::PassManager *pm) {
optPM.addPass(mlir::pphlo::createRewriteDivSqrtPatterns());
}
- optPM.addPass(mlir::pphlo::createExpandSecretGatherPass());
-
if (options.enable_optimize_denominator_with_broadcast()) {
optPM.addPass(mlir::pphlo::createOptimizeDenominatorWithBroadcast());
}
diff --git a/libspu/compiler/front_end/BUILD.bazel b/libspu/compiler/front_end/BUILD.bazel
index 8a6e1f97..399e30b8 100644
--- a/libspu/compiler/front_end/BUILD.bazel
+++ b/libspu/compiler/front_end/BUILD.bazel
@@ -76,5 +76,6 @@ spu_cc_library(
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:Parser",
"@xla//xla/mlir_hlo:mhlo_passes",
+ "@xla//xla/translate/mhlo_to_hlo:translate",
],
)
diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc
index 5d89e7c3..f1185296 100644
--- a/libspu/compiler/front_end/fe.cc
+++ b/libspu/compiler/front_end/fe.cc
@@ -24,6 +24,7 @@
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/translate/mhlo_to_hlo/translate.h"
#include "libspu/compiler/common/compilation_context.h"
#include "libspu/compiler/front_end/hlo_importer.h"
@@ -44,21 +45,33 @@ FE::FE(CompilationContext *ctx) : ctx_(ctx) {
}
mlir::OwningOpRef FE::doit(const CompilationSource &source) {
+ HloImporter importer(ctx_);
mlir::OwningOpRef module;
- switch (source.ir_type()) {
- case spu::SourceIRType::XLA: {
- HloImporter importer(ctx_);
- module = importer.parseXlaModuleFromString(source.ir_txt());
- break;
- }
- case spu::SourceIRType::MLIR_HLO: {
+
+ if (source.ir_type() == spu::SourceIRType::STABLEHLO) {
module = mlir::parseSourceString(source.ir_txt(),
ctx_->getMLIRContext());
- break;
- }
- default: {
- SPU_THROW("Unsupported input IR type = {}", source.ir_type());
- }
+
+ // Convert stablehlo to mhlo first
+ mlir::PassManager pm(ctx_->getMLIRContext());
+ pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
+ if (pm.run(module.get()).failed()) {
+ SPU_THROW("Failed to legalized stablehlo to mhlo");
+ }
+
+ // Convert back to XLA, SPU still relies on XLA to eliminate ops like
+ // batch-normal-inference
+ std::string xla_text;
+ llvm::raw_string_ostream out(xla_text);
+ if (!mlir::failed(xla::MlirHloToHloTranslateFunction(module.get(), out,
+ true, true))) {
+ out.flush();
+ module = importer.parseXlaModuleFromString(xla_text);
+ }
+ } else if (source.ir_type() == spu::SourceIRType::XLA) {
+ module = importer.parseXlaModuleFromString(source.ir_txt());
+ } else {
+ SPU_THROW("Unhandled IR type = {}", source.ir_type());
}
std::string input_vis_str;
diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc
index 149f5a4b..80ba9aba 100644
--- a/libspu/compiler/front_end/hlo_importer.cc
+++ b/libspu/compiler/front_end/hlo_importer.cc
@@ -127,7 +127,7 @@ void runHloPasses(xla::HloModule *module) {
/*allow_mixed_precision=*/false);
pipeline.AddPass();
- pipeline.AddPass(GatherExpander::kEliminateSimpleGathers);
+ pipeline.AddPass(GatherExpander::kEliminateAllGathers);
pipeline.AddPass(ScatterExpander::kEliminateAllScatters);
pipeline.AddPass(options);
pipeline.AddPass();
@@ -163,7 +163,10 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
// If parse as HloModuleProto fails, try HloProto.
xla::HloProto hlo_proto;
if (!hlo_proto.ParseFromString(content)) {
- SPU_THROW("Failed to parse hlo module from string");
+ // Try human-readable format
+ if (!google::protobuf::TextFormat::ParseFromString(content, &hlo_proto)) {
+ SPU_THROW("Failed to parse hlo module from string {}", content);
+ }
}
hlo_module = hlo_proto.hlo_module();
}
diff --git a/libspu/compiler/passes/BUILD.bazel b/libspu/compiler/passes/BUILD.bazel
index c470e86e..ef5ca7d3 100644
--- a/libspu/compiler/passes/BUILD.bazel
+++ b/libspu/compiler/passes/BUILD.bazel
@@ -197,18 +197,6 @@ spu_cc_library(
],
)
-spu_cc_library(
- name = "expand_secret_gather",
- srcs = ["expand_secret_gather.cc"],
- hdrs = ["passes.h"],
- deps = [
- ":pass_details",
- "//libspu/dialect:pphlo_dialect",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:TransformUtils",
- ],
-)
-
spu_cc_library(
name = "rewrite_div_sqrt_patterns",
srcs = ["rewrite_div_sqrt_patterns.cc"],
@@ -276,7 +264,6 @@ spu_cc_library(
":convert_push_down",
":decompose_comparison",
":decompose_minmax",
- ":expand_secret_gather",
":hlo_legalize_to_pphlo",
":insert_deallocation",
":lower_conversion_cast",
diff --git a/libspu/compiler/passes/expand_secret_gather.cc b/libspu/compiler/passes/expand_secret_gather.cc
deleted file mode 100644
index 10117f31..00000000
--- a/libspu/compiler/passes/expand_secret_gather.cc
+++ /dev/null
@@ -1,641 +0,0 @@
-// Copyright 2023 Ant Group Co., Ltd.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include
-#include
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#include "libspu/compiler/passes/pass_details.h"
-#include "libspu/dialect/pphlo_ops.h"
-
-namespace mlir::pphlo {
-
-namespace {
-
-bool GatherIsBroadcast(GatherOp &op) {
- auto gather_slice_size = op.getSliceSizes();
- auto op_shape = op.getOperand().getType().getShape();
- return (gather_slice_size.size() == op_shape.size()) &&
- (std::equal(gather_slice_size.begin(), gather_slice_size.end(),
- op_shape.begin()));
-}
-
-std::vector DeleteDimensions(llvm::ArrayRef dims_to_delete,
- llvm::ArrayRef shape) {
- std::unordered_set ordered_dims_to_delete(dims_to_delete.begin(),
- dims_to_delete.end());
-
- std::vector result;
- result.reserve(shape.size() - ordered_dims_to_delete.size());
-
- for (size_t idx = 0; idx < shape.size(); ++idx) {
- if (ordered_dims_to_delete.count(idx) != 0) {
- continue;
- }
- result.emplace_back(idx);
- }
- return result;
-}
-
-// Computes how many trips a loop implementing this gather op would take.
-int64_t GatherLoopTripCount(GatherOp op) {
- auto start_indices = op.getStartIndices();
- const auto start_indices_shape = start_indices.getType().getShape();
- const auto &dim_numbers = op.getDimensionNumbers();
-
- int64_t trip_count = 1;
- for (int64_t i = 0, e = start_indices_shape.size(); i < e; i++) {
- if (i != dim_numbers.getIndexVectorDim()) {
- trip_count *= start_indices_shape[i];
- }
- }
- return trip_count;
-}
-
-llvm::SmallVector
-ComputePermutedShape(llvm::ArrayRef shape,
- llvm::ArrayRef permutation) {
- llvm::SmallVector result_shape;
- for (auto dim : permutation) {
- result_shape.emplace_back(shape[dim]);
- }
- return result_shape;
-}
-
-TypedValue
-TransposeIndexVectorDimToLast(TypedValue &start_indices,
- int64_t index_vector_dim) {
- const auto start_indices_shape = start_indices.getType().getShape();
-
- if (static_cast(start_indices_shape.size()) == index_vector_dim) {
- return start_indices;
- }
-
- if (index_vector_dim ==
- static_cast(start_indices_shape.size() - 1)) {
- return start_indices;
- }
-
- std::vector permutation;
- permutation.reserve(start_indices_shape.size());
- for (int64_t i = 0, e = start_indices_shape.size(); i < e; i++) {
- if (i != index_vector_dim) {
- permutation.emplace_back(i);
- }
- }
- permutation.emplace_back(index_vector_dim);
-
- auto result_shape = ComputePermutedShape(start_indices_shape, permutation);
-
- OpBuilder builder(start_indices.getContext());
- if (auto *ip = start_indices.getDefiningOp()) {
- builder.setInsertionPointAfter(ip);
- } else {
- builder.setInsertionPointToStart(start_indices.getParentBlock());
- }
-
- auto transpose = builder.create(
- start_indices.getLoc(),
- RankedTensorType::get(result_shape,
- start_indices.getType().getElementType()),
- start_indices, permutation);
-
- return transpose.getResult();
-}
-
-TypedValue
-PrependDegenerateDims(TypedValue operand, int64_t n) {
- SPU_ENFORCE(n > 0);
- std::vector new_shape_dims;
- const auto operand_shape = operand.getType().getShape();
- new_shape_dims.reserve(n + operand_shape.size());
- new_shape_dims.insert(new_shape_dims.begin(), n, 1);
- std::copy(operand_shape.begin(), operand_shape.end(),
- std::back_inserter(new_shape_dims));
-
- OpBuilder builder(operand.getContext());
- if (auto *ip = operand.getDefiningOp()) {
- builder.setInsertionPointAfter(ip);
- } else {
- builder.setInsertionPointToStart(operand.getParentBlock());
- }
-
- auto reshape = builder.create(
- operand.getLoc(),
- RankedTensorType::get(new_shape_dims, operand.getType().getElementType()),
- operand);
-
- return reshape.getResult();
-}
-
-TypedValue
-CollapseFirstNDims(TypedValue operand, int64_t n) {
- SPU_ENFORCE(n > 0);
-
- const auto operand_shape = operand.getType().getShape();
- SPU_ENFORCE((int64_t)operand_shape.size() >= n);
-
- int64_t new_shape_leading_bound = 1;
- for (int64_t i = 0; i < n; i++) {
- new_shape_leading_bound *= operand_shape[i];
- }
-
- std::vector new_shape_dims;
- new_shape_dims.reserve(operand_shape.size() - n + 1);
- new_shape_dims.push_back(new_shape_leading_bound);
-
- std::copy(operand_shape.begin() + n, operand_shape.end(),
- std::back_inserter(new_shape_dims));
-
- auto output_type =
- RankedTensorType::get(new_shape_dims, operand.getType().getElementType());
-
- OpBuilder builder(operand.getContext());
- if (auto *ip = operand.getDefiningOp()) {
- builder.setInsertionPointAfter(ip);
- } else {
- builder.setInsertionPointToStart(operand.getParentBlock());
- }
-
- auto reshape =
- builder.create(operand.getLoc(), output_type, operand);
-
- return reshape.getResult();
-}
-
-// Canonicalizes the start_indices tensors so that we only have deal with some
-// specific cases in the while loop that does the heavy lifting.
-//
-// See the "High Level Algorithm" section for a broader picture.
-TypedValue
-CanonicalizeGatherIndices(TypedValue &start_indices,
- int64_t index_vector_dim) {
- // Transpose the non-index-vector dimensions to the front.
- auto transposed_start_indices =
- TransposeIndexVectorDimToLast(start_indices, index_vector_dim);
- bool indices_are_scalar =
- index_vector_dim ==
- static_cast(start_indices.getType().getShape().size());
-
- // The number of dimensions in start_indices that are index dimensions.
- const int64_t index_dims_in_start_indices = indices_are_scalar ? 0 : 1;
-
- // If there is only one index (i.e. start_indices has rank 1 and this gather
- // is really just a dynamic slice) add a leading degenerate dimension for
- // uniformity. Otherwise create a "collapsed" leading dimension that subsumes
- // all of the non-index-vector dimensions.
- const auto shape = transposed_start_indices.getType().getShape();
- if (static_cast(shape.size()) == index_dims_in_start_indices) {
- return PrependDegenerateDims(transposed_start_indices, 1);
- } else {
- // Collapse all but the dimensions (0 or 1) in start_indices containing the
- // index vectors.
- return CollapseFirstNDims(transposed_start_indices,
- shape.size() - index_dims_in_start_indices);
- }
-}
-
-TypedValue CreateGatherLoopAccumulatorInitValue(
- GatherOp op, Type element_type, llvm::ArrayRef slice_sizes,
- int64_t gather_loop_trip_count,
- const GatherDimensionNumbersAttr &dim_numbers) {
- std::vector accumulator_state_shape_dims;
- accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
- accumulator_state_shape_dims.push_back(gather_loop_trip_count);
- for (int64_t i = 0; i < static_cast(slice_sizes.size()); i++) {
- if (!std::binary_search(dim_numbers.getCollapsedSliceDims().begin(),
- dim_numbers.getCollapsedSliceDims().end(), i)) {
- accumulator_state_shape_dims.emplace_back(slice_sizes[i]);
- }
- }
-
- OpBuilder builder(op);
- TypeTools type_tools;
-
- auto express_type = type_tools.getExpressedType(element_type);
- auto shaped_type =
- RankedTensorType::get(accumulator_state_shape_dims, express_type);
- auto zero_attr = builder.getZeroAttr(shaped_type);
-
- if (zero_attr == nullptr && express_type.isa()) {
- std::complex zero = {APFloat(0.0F), APFloat(0.0F)};
- zero_attr = DenseElementsAttr::get(shaped_type,
- std::vector>(
- shaped_type.getNumElements(), zero));
- }
-
- auto c = builder.create(op->getLoc(), zero_attr);
-
- if (type_tools.getTypeVisibility(element_type) != Visibility::VIS_PUBLIC) {
- auto convert = builder.create(
- op.getLoc(),
- RankedTensorType::get(accumulator_state_shape_dims, element_type),
- c.getResult());
- return convert.getResult();
- } else {
- return c.getResult();
- }
-}
-
-TypedValue
-ExpandFirstDimIntoNDims(TypedValue operand,
- llvm::ArrayRef expanded_dims) {
- SPU_ENFORCE_GT(operand.getType().getShape().size(), size_t(0));
- SPU_ENFORCE_EQ(operand.getType().getShape()[0],
- std::accumulate(expanded_dims.begin(), expanded_dims.end(), 1,
- std::multiplies()));
-
- std::vector expanded_shape_dim_bounds;
- expanded_shape_dim_bounds.reserve(expanded_dims.size() +
- operand.getType().getShape().size() - 1);
- std::copy(expanded_dims.begin(), expanded_dims.end(),
- std::back_inserter(expanded_shape_dim_bounds));
- std::copy(operand.getType().getShape().begin() + 1,
- operand.getType().getShape().end(),
- std::back_inserter(expanded_shape_dim_bounds));
-
- auto result_type = RankedTensorType::get(expanded_shape_dim_bounds,
- operand.getType().getElementType());
-
- OpBuilder builder(operand.getContext());
- if (auto *ip = operand.getDefiningOp()) {
- builder.setInsertionPointAfter(ip);
- } else {
- builder.setInsertionPointToStart(operand.getParentBlock());
- }
- auto reshaped =
- builder.create(operand.getLoc(), result_type, operand);
- return reshaped.getResult();
-}
-
-TypedValue
-ElideDegenerateDims(OpBuilder *builder, TypedValue operand,
- absl::Span dims_to_elide) {
- std::unordered_set dims_to_elide_set(dims_to_elide.begin(),
- dims_to_elide.end());
- std::vector new_shape;
- for (size_t idx = 0; idx < operand.getType().getShape().size(); ++idx) {
- if (dims_to_elide_set.count(idx) > 0) {
- continue;
- }
- new_shape.emplace_back(operand.getType().getShape()[idx]);
- }
-
- auto reshape = builder->create(
- operand.getLoc(),
- RankedTensorType::get(new_shape, operand.getType().getElementType()),
- operand);
- return reshape.getResult();
-}
-
-// Expands out or contracts away the gather dimensions in the accumulator
-// produced by the while loop.
-TypedValue AdjustBatchDimsInAccumulator(
- OpBuilder *builder, llvm::ArrayRef start_indices_shape,
- TypedValue accumulator, int64_t index_vector_dim) {
- std::vector batch_dim_bounds;
- batch_dim_bounds.reserve(start_indices_shape.size());
- for (int64_t i = 0, e = start_indices_shape.size(); i < e; i++) {
- if (i != index_vector_dim) {
- batch_dim_bounds.push_back(start_indices_shape[i]);
- }
- }
-
- if (batch_dim_bounds.empty()) {
- // If batch_dim_bounds is empty we must be lowering a (effectively)
- // dynamic-slice. In that case, there is a leading degenerate gather
- // dimension that we added to make this special case play well with the
- // general while loop which we need to remove now.
- return ElideDegenerateDims(builder, accumulator, {0});
- }
-
- return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
-}
-
-void BuildWhileCondition(Region &cond, Value /*counter*/,
- Value /*canonical_start_indices*/,
- Value /*accumulator_init*/, Value loop_upper_bound) {
- OpBuilder builder(cond);
- TypeTools type_tool;
-
- auto lt = builder.create(
- cond.getLoc(),
- RankedTensorType::get(
- {}, type_tool.getTypeWithVisibility(builder.getI1Type(),
- Visibility::VIS_PUBLIC)),
- cond.getArgument(0), loop_upper_bound);
-
- builder.create(cond.getLoc(), ValueRange{lt.getResult()});
-}
-
-int64_t FindIndex(llvm::ArrayRef c, int64_t value) {
- const auto *it = std::find(c.begin(), c.end(), value);
- return std::distance(c.begin(), it);
-}
-
-// Expand an index vector from the start_indices tensor into a vector that can
-// be used to dynamic-slice out of the gather operand.
-llvm::SmallVector ExpandIndexVectorIntoOperandSpace(
- OpBuilder *builder, TypedValue index_vector,
- const GatherDimensionNumbersAttr &dim_numbers, int64_t operand_rank) {
-
- TypeTools typetool;
- auto index_type =
- typetool.getExpressedType(index_vector.getType().getElementType());
-
- if (operand_rank == 0) {
- // This is Gather from a scalar. So, the index vector in operand space must
- // be a zero-sized vector.
- // return computation->AddInstruction(HloInstruction::CreateConstant(
- // LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
- auto zero_const = builder->create(
- index_vector.getLoc(),
- builder->getZeroAttr(RankedTensorType::get({}, index_type)));
- return {zero_const.getResult()};
- }
-
- auto p_zero_const = builder->create(
- index_vector.getLoc(),
- builder->getZeroAttr(RankedTensorType::get({}, index_type)));
-
- auto zero_const = builder->create(
- index_vector.getLoc(),
- RankedTensorType::get({}, typetool.toMPCType(index_type)),
- p_zero_const);
-
- // We extract out individual components from the smaller index and concatenate
- // them (interspersing zeros as needed) into the larger index.
- llvm::SmallVector expanded_index_components;
-
- for (int64_t i = 0; i < operand_rank; i++) {
- int64_t index_vector_dim_index =
- FindIndex(dim_numbers.getStartIndexMap(), i);
- if (index_vector_dim_index !=
- static_cast(dim_numbers.getStartIndexMap().size())) {
-
- auto component_to_concat = builder->create(
- index_vector.getLoc(),
- RankedTensorType::get({1}, index_vector.getType().getElementType()),
- index_vector,
- DenseI64ArrayAttr::get(builder->getContext(),
- {index_vector_dim_index}),
- DenseI64ArrayAttr::get(builder->getContext(),
- {index_vector_dim_index + 1}),
- DenseI64ArrayAttr::get(builder->getContext(), {1}));
- auto reshaped = builder->create(
- index_vector.getLoc(),
- RankedTensorType::get({}, index_vector.getType().getElementType()),
- component_to_concat);
- expanded_index_components.push_back(reshaped);
- } else {
- expanded_index_components.push_back(zero_const);
- }
- }
-
- return expanded_index_components;
-}
-
-// This generates the body of the while that implements the main data movement
-// behavior of gather using dynamic-slice and dynamic-update-slice.
-void GatherLoopBody(GatherOp gather, Region &body,
- TypedValue operand,
- TypedValue start_indices) {
- OpBuilder builder(body);
-
- auto induction_var = body.getArgument(0);
- auto output_accumulator = body.getArgument(1);
-
- TypeTools typetools;
- auto index_type = typetools.getExpressedType(
- induction_var.getType().dyn_cast().getElementType());
-
- // Increment counter first
- auto const_one = builder.create(
- gather->getLoc(),
- DenseElementsAttr::get(RankedTensorType::get({}, index_type),
- builder.getIntegerAttr(index_type, 1)));
-
- // counter + 1
- auto incremented_counter =
- builder.create(induction_var.getLoc(), induction_var.getType(),
- induction_var, const_one);
-
- const auto &dim_numbers = gather.getDimensionNumbers();
-
- bool has_scalar_indices = start_indices.getType().getShape().size() == 1;
- SPU_ENFORCE_EQ(
- has_scalar_indices,
- dim_numbers.getIndexVectorDim() ==
- (int64_t)gather.getStartIndices().getType().getShape().size());
-
- auto index_zero = builder.create(
- gather->getLoc(),
- builder.getZeroAttr(RankedTensorType::get({}, index_type)));
-
- TypedValue index_vector;
-
- if (has_scalar_indices) {
- // In this case start_indices has rank 1 and induction_var_as_vector (of
- // shape {1}) is an index into this rank 1 tensor.
- auto ds = builder.create(
- gather->getLoc(), start_indices, ValueRange{induction_var},
- DenseI64ArrayAttr::get(builder.getContext(), {1}));
- index_vector = ds.getResult();
- } else {
- // In this case start_indices has rank 2 and induction_var_as_vector (of
- // shape {1}) is an index into just the first dimension of this rank 2
- // tensor.
-
- int64_t index_vector_size = start_indices.getType().getShape()[1];
-
- auto index_vector_2d = builder.create(
- gather->getLoc(), start_indices, ValueRange{induction_var, index_zero},
- DenseI64ArrayAttr::get(builder.getContext(), {1, index_vector_size}));
-
- index_vector = ElideDegenerateDims(&builder, index_vector_2d, {0});
- }
-
- auto gathered_slice_start = ExpandIndexVectorIntoOperandSpace(
- &builder, index_vector, dim_numbers, operand.getType().getShape().size());
-
- auto gathered_slice = builder.create(
- gather->getLoc(), operand, gathered_slice_start, gather.getSliceSizes());
-
- auto gathered_slice_with_dims_collapsed = ElideDegenerateDims(
- &builder, gathered_slice, dim_numbers.getCollapsedSliceDims());
-
- auto gathered_slice_for_update =
- PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1);
-
- SmallVector index_vector_into_accumulator;
- index_vector_into_accumulator.push_back(induction_var);
- for (size_t idx = 0;
- idx < gathered_slice_with_dims_collapsed.getType().getShape().size();
- ++idx) {
- index_vector_into_accumulator.push_back(index_zero);
- }
-
- auto updated_accumulator = builder.create(
- gather->getLoc(), output_accumulator, gathered_slice_for_update,
- index_vector_into_accumulator);
-
- builder.create(
- gather->getLoc(), ValueRange{incremented_counter, updated_accumulator});
-}
-
-struct GatherConverter : public OpRewritePattern {
- explicit GatherConverter(MLIRContext *context) : OpRewritePattern(context) {}
-
- LogicalResult matchAndRewrite(GatherOp op,
- PatternRewriter &rewriter) const override {
-
- TypeTools type_tool;
- if (type_tool.getTypeVisibility(op.getStartIndices().getType()) !=
- Visibility::VIS_SECRET) {
- // Do not expand public gather
- return failure();
- }
-
- OpBuilder builder(op);
-
- // Secret gather
- if (GatherIsBroadcast(op)) {
- // Replace gather with broadcast
- auto broadcast_operand_shape =
- DeleteDimensions(op.getDimensionNumbers().getCollapsedSliceDims(),
- op.getOperand().getType().getShape());
- auto reshaped_type = RankedTensorType::get(
- broadcast_operand_shape, op.getOperand().getType().getElementType());
- auto broadcast_operand = builder.create(
- op->getLoc(), reshaped_type, op.getOperand());
- rewriter.replaceOpWithNewOp(
- op, op->getResults().getType(), broadcast_operand,
- DenseI64ArrayAttr::get(builder.getContext(),
- op.getDimensionNumbers().getOffsetDims()));
- return success();
- }
-
- auto index_type = type_tool.getExpressedType(
- op.getStartIndices().getType().getElementType());
- auto operand = op.getOperand();
- auto start_indices = op.getStartIndices();
- auto output_type = op->getResultTypes()[0].dyn_cast();
- auto output_shape = output_type.getShape();
- int64_t output_rank = output_shape.size();
-
- const auto &dim_numbers = op.getDimensionNumbers();
-
- int64_t gather_loop_trip_count = GatherLoopTripCount(op);
-
- auto canonical_start_indices = CanonicalizeGatherIndices(
- start_indices, dim_numbers.getIndexVectorDim());
-
- SPU_ENFORCE(gather_loop_trip_count ==
- canonical_start_indices.getType().getShape()[0]);
-
- auto accumulator_init = CreateGatherLoopAccumulatorInitValue(
- op, output_type.getElementType(), op.getSliceSizes(),
- gather_loop_trip_count, op.getDimensionNumbers());
-
- auto loopUpperBound = builder.create(
- op->getLoc(),
- DenseElementsAttr::get(
- RankedTensorType::get({}, index_type),
- builder.getIntegerAttr(index_type, gather_loop_trip_count)));
-
- auto counter = builder.create(
- op->getLoc(),
- builder.getZeroAttr(RankedTensorType::get({}, index_type)));
-
- auto loop = builder.create(
- op->getLoc(),
- TypeRange{counter.getResult().getType(), accumulator_init.getType()},
- ValueRange{counter, accumulator_init});
- {
- loop.getCond().push_back(new Block());
- loop.getCond().front().addArguments(
- TypeRange{counter.getType(), accumulator_init.getType()},
- {counter.getLoc(), accumulator_init.getLoc()});
- }
- {
- loop.getBody().push_back(new Block());
-
- loop.getBody().front().addArguments(
- TypeRange{counter.getType(), accumulator_init.getType()},
- {counter.getLoc(), accumulator_init.getLoc()});
- }
- // Generate loop condition
- BuildWhileCondition(loop.getCond(), counter.getResult(),
- canonical_start_indices, accumulator_init,
- loopUpperBound.getResult());
-
- GatherLoopBody(op, loop.getBody(), operand, canonical_start_indices);
-
- OpResult accumulator_result = loop->getResults().back();
-
- auto accumulator_with_batch_dims_decanonicalized =
- AdjustBatchDimsInAccumulator(
- &builder, start_indices.getType().getShape(),
- cast>(accumulator_result),
- dim_numbers.getIndexVectorDim());
-
- std::vector permutation;
- permutation.reserve(output_rank);
-
- int64_t batch_idx_counter = 0;
- int64_t offset_idx_counter =
- output_rank - dim_numbers.getOffsetDims().size();
- for (int64_t i = 0; i < output_rank; i++) {
- bool is_offset_dim =
- std::binary_search(dim_numbers.getOffsetDims().begin(),
- dim_numbers.getOffsetDims().end(), i);
- if (is_offset_dim) {
- permutation.push_back(offset_idx_counter++);
- } else {
- permutation.push_back(batch_idx_counter++);
- }
- }
-
- rewriter.replaceOpWithNewOp(
- op, op.getResult().getType(),
- accumulator_with_batch_dims_decanonicalized,
- DenseI64ArrayAttr::get(builder.getContext(), permutation));
-
- return success();
- }
-};
-
-struct ExpandSecretGather : public ExpandSecretGatherBase {
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateOwningPatterns(&patterns, &getContext());
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- }
-
-private:
- static void populateOwningPatterns(RewritePatternSet *patterns,
- MLIRContext *ctx) {
- patterns->insert(ctx);
- }
-};
-} // namespace
-
-std::unique_ptr> createExpandSecretGatherPass() {
- return std::make_unique();
-}
-
-} // namespace mlir::pphlo
diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc
index c5c78b7f..c55e552c 100644
--- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc
+++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc
@@ -1226,41 +1226,6 @@ class HloToPPHloOpConverter
}
};
-template <>
-class HloToPPHloOpConverter
- : public OpConversionPattern {
-private:
- const ValueVisibilityMap &vis_;
-
-public:
- HloToPPHloOpConverter(TypeConverter &type_converter, MLIRContext *context,
- const ValueVisibilityMap &vis)
- : OpConversionPattern(type_converter, context),
- vis_(vis) {}
-
- LogicalResult
- matchAndRewrite(stablehlo::GatherOp op, stablehlo::GatherOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto old_attr = op.getDimensionNumbers();
- pphlo::GatherDimensionNumbersAttr attr = GatherDimensionNumbersAttr::get(
- op.getContext(), old_attr.getOffsetDims(),
- old_attr.getCollapsedSliceDims(), old_attr.getStartIndexMap(),
- old_attr.getIndexVectorDim());
-
- auto result_vis = vis_.getValueVisibility(op.getResult());
-
- Type resultType = HloToPPHloTypeConverter::getTypeWithVisibility(
- this->getTypeConverter()->convertType(op.getType()), result_vis);
-
- rewriter.replaceOpWithNewOp(
- op, resultType, adaptor.getOperands()[0], adaptor.getOperands()[1],
- attr, ConvertDenseIntElementAttr(op.getSliceSizes()),
- op.getIndicesAreSorted());
-
- return success();
- }
-};
-
template <>
class HloToPPHloOpConverter
: public OpConversionPattern {
@@ -1628,7 +1593,6 @@ struct HloLegalizeToPPHlo
HloToPPHloOpConverter,
HloToPPHloOpConverter,
HloToPPHloOpConverter,
- HloToPPHloOpConverter,
HloToPPHloOpConverter,
HloToPPHloOpConverter,
HloToPPHloOpConverter,
diff --git a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h
index 4aa3f8a5..b409051a 100644
--- a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h
+++ b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h
@@ -58,7 +58,6 @@ MAP_HLO_TO_PPHLO(DivOp)
MAP_HLO_TO_PPHLO(DotOp)
MAP_HLO_TO_PPHLO(ExpOp)
MAP_HLO_TO_PPHLO(Expm1Op)
-MAP_HLO_TO_PPHLO(GatherOp)
MAP_HLO_TO_PPHLO(IotaOp)
MAP_HLO_TO_PPHLO(FloorOp)
MAP_HLO_TO_PPHLO(LogOp)
diff --git a/libspu/compiler/passes/passes.h b/libspu/compiler/passes/passes.h
index fbda4cfb..16c70cab 100644
--- a/libspu/compiler/passes/passes.h
+++ b/libspu/compiler/passes/passes.h
@@ -62,8 +62,6 @@ std::unique_ptr> createOptimizeSelectPass();
// Optimize sqrt(x) + very_small_const) -> sqrt(x + eps)
std::unique_ptr> createOptimizeSqrtPlusEps();
-std::unique_ptr> createExpandSecretGatherPass();
-
// Rewrite x/sqrt(x+eps) -> x*rsqrt(x+eps)
std::unique_ptr> createRewriteDivSqrtPatterns();
diff --git a/libspu/compiler/passes/passes.td b/libspu/compiler/passes/passes.td
index 281649d4..942d531f 100644
--- a/libspu/compiler/passes/passes.td
+++ b/libspu/compiler/passes/passes.td
@@ -81,12 +81,6 @@ def RewriteDivSqrtPatterns: Pass<"rewrite-div-sqrt-pattern", "func::FuncOp"> {
let dependentDialects = ["pphlo::PPHloDialect"];
}
-def ExpandSecretGather: Pass<"expand-secret-gather", "func::FuncOp"> {
- let summary = "Rewrite Gather with secret indexing to loop with DynamicUpdateSlice";
- let constructor = "createExpandSecretGatherPass()";
- let dependentDialects = ["pphlo::PPHloDialect"];
-}
-
def OptimizeDenominatorWithBcast: Pass<"optimize-denominator-with-broadcast", "func::FuncOp"> {
let summary = "Optimize x/broadcast(y) into x*broadcast(1/y)";
let constructor = "createOptimizeDenominatorWithBroadcast()";
diff --git a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc
index a975799e..e4a4ae21 100644
--- a/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc
+++ b/libspu/compiler/passes/rewrite_div_sqrt_patterns.cc
@@ -27,23 +27,48 @@ namespace mlir::pphlo {
namespace {
struct DivRewriter : public OpRewritePattern {
+private:
+ Operation *rewriteSqrtIfPossible(PatternRewriter &rewriter,
+ Operation *op) const {
+ if (op == nullptr || op->getNumOperands() != 1) {
+ return nullptr;
+ }
+
+ if (mlir::isa(op)) {
+ return rewriter.create(op->getLoc(), op->getResultTypes(),
+ op->getOperand(0));
+ }
+
+ if (auto bcastOp = mlir::dyn_cast(op)) {
+ if (auto *inner = rewriteSqrtIfPossible(
+ rewriter, bcastOp.getOperand().getDefiningOp())) {
+ return rewriter.create(
+ op->getLoc(), bcastOp->getResultTypes(), inner->getResult(0),
+ bcastOp.getBroadcastDimensions());
+ }
+ return nullptr;
+ }
+
+ return nullptr;
+ }
+
+public:
explicit DivRewriter(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(DivOp op,
PatternRewriter &rewriter) const override {
// Pattern 1:
- // y/sqrt(x + eps)
+ // y/sqrt(x) -> y*rsqrt(x)
auto denominator = op.getRhs();
- if (auto sqrt = denominator.getDefiningOp()) {
- auto newRsqrt = rewriter.create(
- denominator.getLoc(), denominator.getType(), sqrt.getOperand());
+ if (auto *newop =
+ rewriteSqrtIfPossible(rewriter, denominator.getDefiningOp())) {
rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(),
- newRsqrt);
+ newop->getResult(0));
return success();
} else {
// Pattern 2:
- // y/(k*sqrt(x + eps)) -> y/k*rsqrt(x+eps)
+ // y/(k*sqrt(x)) -> y/k*rsqrt(x)
if (auto mulOp = denominator.getDefiningOp()) {
auto sqrtOp = mulOp.getRhs().getDefiningOp();
auto k = mulOp.getLhs();
@@ -55,10 +80,10 @@ struct DivRewriter : public OpRewritePattern {
// y/k
auto newDiv = rewriter.create(
op.getLoc(), op->getResultTypes(), op.getLhs(), k);
- // rsqrt(x+eps)
+ // rsqrt(x)
auto newRsqrt = rewriter.create(
op->getLoc(), sqrtOp->getResultTypes(), sqrtOp->getOperand(0));
- // y/k*rsqrt(x+eps)
+ // y/k*rsqrt(x)
rewriter.replaceOpWithNewOp(op, op.getType(), newDiv,
newRsqrt);
return success();
diff --git a/libspu/compiler/passes/visibility_inference.cc b/libspu/compiler/passes/visibility_inference.cc
index 3b6eae6c..b238d34f 100644
--- a/libspu/compiler/passes/visibility_inference.cc
+++ b/libspu/compiler/passes/visibility_inference.cc
@@ -306,14 +306,6 @@ void VisibilityInference::inferOperation(Operation &op) {
value_vis_.setValueVisibility(op.getResult(0), Visibility::VIS_PUBLIC);
} else if (llvm::isa(op)) {
inferSort(op);
- } else if (llvm::isa(op)) {
- // For gather op, if either operand or indices is a secret, result is a
- // secret
- auto operand_vis = value_vis_.getValueVisibility(op.getOperand(0));
- auto indices_vis = value_vis_.getValueVisibility(op.getOperand(1));
- value_vis_.setValueVisibility(
- op.getResult(0),
- TypeTools::inferResultVisibility({operand_vis, indices_vis}));
} else if (llvm::isa(op)) {
inferSelectAndScatter(op);
} else if (llvm::isa(op)) {
diff --git a/libspu/compiler/tests/expand_secret_gather.mlir b/libspu/compiler/tests/expand_secret_gather.mlir
deleted file mode 100644
index 36ac1553..00000000
--- a/libspu/compiler/tests/expand_secret_gather.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: mlir-pphlo-opt --expand-secret-gather --split-input-file %s | FileCheck %s
-
-func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<1x!pphlo.sec>) -> (tensor>) {
- //CHECK-NOT: pphlo.gather
- //CHECK : pphlo.while
- %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = array} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.sec>) -> tensor>
- return %0: tensor>
-}
-
-// -----
-func.func @main(%arg0: tensor<3x3x!pphlo.pub>, %arg1: tensor<2x!pphlo.sec>) -> (tensor<2x3x!pphlo.sec>) {
- %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x3x!pphlo.pub>, tensor<2x!pphlo.sec>) -> tensor<2x3x!pphlo.sec>
- return %0 : tensor<2x3x!pphlo.sec>
-}
diff --git a/libspu/compiler/tests/hlo_to_pphlo_dynamic_slice.mlir b/libspu/compiler/tests/hlo_to_pphlo_dynamic_slice.mlir
new file mode 100644
index 00000000..3196a62c
--- /dev/null
+++ b/libspu/compiler/tests/hlo_to_pphlo_dynamic_slice.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-pphlo-opt --hlo-legalize-to-pphlo=input_vis_list=VIS_PUBLIC,VIS_SECRET --split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: tensor<15xi32>,%arg1: tensor) -> (tensor<1xi32>) {
+ // CHECK: %0 = "pphlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = array} : (tensor<15x!pphlo.pub>, tensor>) -> tensor<1x!pphlo.sec>
+ %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) {slice_sizes = array} : (tensor<15xi32>, tensor) -> tensor<1xi32>
+ return %0 : tensor<1xi32>
+}
diff --git a/libspu/compiler/tests/no_expand_secret_gather.mlir b/libspu/compiler/tests/no_expand_secret_gather.mlir
deleted file mode 100644
index 155a4dca..00000000
--- a/libspu/compiler/tests/no_expand_secret_gather.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: mlir-pphlo-opt --expand-secret-gather --split-input-file %s | FileCheck %s
-
-func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<1x!pphlo.pub>) -> (tensor>) {
- //CHECK-NOT: pphlo.while
- //CHECK : pphlo.gather
- %0 = "pphlo.gather"(%arg0, %arg1) {dimension_numbers = #pphlo.gather, indices_are_sorted = true, slice_sizes = array} : (tensor<2x!pphlo.pub>, tensor<1x!pphlo.pub>) -> tensor>
- return %0: tensor>
-}
diff --git a/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir b/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir
index 88b7738d..6289c118 100644
--- a/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir
+++ b/libspu/compiler/tests/optimize_sqrt_to_rsqrt.mlir
@@ -34,3 +34,29 @@ func.func @main(%arg0: tensor>, %arg1: tensor>)
return %4: tensor>
}
+// -----
+
+func.func @main(%arg0: tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> {
+ // CHECK: %[[RSQRT:.+]] = "pphlo.rsqrt"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ // CHECK: "pphlo.multiply"(%arg0, %[[RSQRT]]) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ %0 = "pphlo.sqrt"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ %1 = "pphlo.divide"(%arg0, %0) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ return %1 : tensor<3x4x!pphlo.sec>
+}
+
+// -----
+
+func.func @main(%arg0: tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec> {
+ %0 = "pphlo.convert"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ %1 = "pphlo.reshape"(%arg0) : (tensor<3x4x!pphlo.sec>) -> tensor<3x4x1x!pphlo.sec>
+ %2 = "pphlo.transpose"(%1) {permutation = array} : (tensor<3x4x1x!pphlo.sec>) -> tensor<3x1x4x!pphlo.sec>
+ %3 = "pphlo.dot_general"(%2, %1) {dot_dimension_numbers = #pphlo.dot} : (tensor<3x1x4x!pphlo.sec>, tensor<3x4x1x!pphlo.sec>) -> tensor<3x!pphlo.sec>
+ %4 = "pphlo.convert"(%3) : (tensor<3x!pphlo.sec>) -> tensor<3x!pphlo.sec>
+ // CHECK: %[[RSQRT:.+]] = "pphlo.rsqrt"
+ // CHECK: %[[BCAST:.+]] = "pphlo.broadcast"(%[[RSQRT]]) {broadcast_dimensions = array} : (tensor<3x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ // CHECK: "pphlo.multiply"(%0, %[[BCAST]]) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ %5 = "pphlo.sqrt"(%4) : (tensor<3x!pphlo.sec>) -> tensor<3x!pphlo.sec>
+ %6 = "pphlo.broadcast"(%5) {broadcast_dimensions = array} : (tensor<3x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ %7 = "pphlo.divide"(%0, %6) : (tensor<3x4x!pphlo.sec>, tensor<3x4x!pphlo.sec>) -> tensor<3x4x!pphlo.sec>
+ return %7 : tensor<3x4x!pphlo.sec>
+}
diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc
index 6ed81a4d..248fe9e4 100644
--- a/libspu/device/pphlo/pphlo_executor.cc
+++ b/libspu/device/pphlo/pphlo_executor.cc
@@ -481,36 +481,6 @@ void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
opts);
}
-void execute(OpExecutor *, SPUContext *sctx, SymbolScope *sscope,
- mlir::pphlo::GatherOp &op, const ExecutionOptions &opts) {
- // If input is empty, short circuit
- auto operand = lookupValue(sscope, op.getOperand(), opts);
- auto start_indices = lookupValue(sscope, op.getStartIndices(), opts);
- if (operand.numel() == 0) {
- addValue(sscope, op.getResult(), operand, opts);
- return;
- }
-
- const auto &output_shape =
- op.getResult().getType().dyn_cast().getShape();
-
- const auto &dim_numbers = op.getDimensionNumbers();
-
- kernel::hlo::GatherConfig config;
- // Sizes ss;
- // convertDenseIntElementAttr(op.getSliceSizes(), ss);
- config.sliceSizes = op.getSliceSizes();
- config.indexVectorDim = dim_numbers.getIndexVectorDim();
- config.offsetDims = dim_numbers.getOffsetDims();
- config.collapsedSliceDims = dim_numbers.getCollapsedSliceDims();
- config.startIndexMap = dim_numbers.getStartIndexMap();
-
- addValue(
- sscope, op.getResult(),
- kernel::hlo::Gather(sctx, operand, start_indices, config, output_shape),
- opts);
-}
-
void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope,
mlir::pphlo::SortOp &op, const ExecutionOptions &opts) {
auto sort_dim = op.getDimension();
diff --git a/libspu/device/pphlo/pphlo_executor_test.cc b/libspu/device/pphlo/pphlo_executor_test.cc
index 46362969..8d5a09ae 100644
--- a/libspu/device/pphlo/pphlo_executor_test.cc
+++ b/libspu/device/pphlo/pphlo_executor_test.cc
@@ -449,11 +449,11 @@ TEST_P(ExecutorTest, ReduceWindowStableHloTest) {
r.run(r.compileMHlo(R"(
func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<2x2xi32>) {
- %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor
- %1 = "mhlo.reduce_window"(%arg0, %0) ( {
+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor
+ %1 = "stablehlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor, %arg2: tensor): // no predecessors
- %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor
- "mhlo.return"(%2) : (tensor) -> ()
+ %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor
+ "stablehlo.return"(%2) : (tensor) -> ()
}) {
base_dilations = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>,
@@ -478,11 +478,11 @@ TEST_P(ExecutorTest, ReduceWindowStableHloTest2) {
r.run(r.compileMHlo(R"(
func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<1x2xi32>) {
- %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor
- %1 = "mhlo.reduce_window"(%arg0, %0) ( {
+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor
+ %1 = "stablehlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor, %arg2: tensor): // no predecessors
- %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor
- "mhlo.return"(%2) : (tensor) -> ()
+ %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor
+ "stablehlo.return"(%2) : (tensor) -> ()
}) {
base_dilations = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>,
@@ -583,11 +583,11 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaBaseDilation) {
r.run(r.compileMHlo(R"(
func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<6x6xi32>) {
- %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor
- %1 = "mhlo.reduce_window"(%arg0, %0) ( {
+ %0 = "stablehlo.constant"() {value = dense<0> : tensor} : () -> tensor
+ %1 = "stablehlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor, %arg2: tensor): // no predecessors
- %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor
- "mhlo.return"(%2) : (tensor) -> ()
+ %2 = "stablehlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor
+ "stablehlo.return"(%2) : (tensor) -> ()
}) {
base_dilations = dense<2> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
@@ -615,11 +615,11 @@ TEST_P(ExecutorTest, ReduceWindowMaxIotaStrideBaseDilation) {
auto compiled = r.compileMHlo(R"(
func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) {
- %0 = "mhlo.constant"() {value = dense<0> : tensor