Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⬆️ TensorFlow 2.9 #724

Merged
merged 11 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bazelversion
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4.2.1
5.0.0
4 changes: 2 additions & 2 deletions .github/tools/release_linux.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ bazel build :build_pip_pkg \
--copt=-mavx \
--distinct_host_configuration=false \
--verbose_failures \
--crosstool_top=//third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11:toolchain
--crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain

# Package Whl
bazel-bin/build_pip_pkg artifacts

# Remove manylinux2010 config flags so that normal builds work as expected
# Remove manylinux2014 config flags so that normal builds work as expected
rm -f .lce_configure.bazelrc
8 changes: 4 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ jobs:
path: wheelhouse

manylinux-release-wheel:
name: Build release wheels for manylinux2010
name: Build release wheels for manylinux2014
runs-on: ubuntu-18.04
strategy:
matrix:
Expand All @@ -228,7 +228,7 @@ jobs:
continue-on-error: true
with:
credentials_json: ${{ secrets.gcs_bazel_cache }}
- name: Build manylinux2010 wheels
- name: Build manylinux2014 wheels
run: |
if [[ -n $GOOGLE_APPLICATION_CREDENTIALS ]]; then
echo -e 'build --remote_http_cache=https://storage.googleapis.com/plumerai-bazel-cache/lce-release-manylinux-python${{ matrix.python-version }}' >> .bazelrc.user
Expand All @@ -239,14 +239,14 @@ jobs:
-e GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcloud-credentials.json \
-v $GOOGLE_APPLICATION_CREDENTIALS:/tmp/gcloud-credentials.json:ro \
-v ${PWD}:/compute-engine -w /compute-engine \
tensorflow/build:2.8-python${{ matrix.python-version }} \
tensorflow/build:2.9-python${{ matrix.python-version }} \
.github/tools/release_linux.sh

sudo apt-get -y -qq install patchelf --no-install-recommends
python -m pip install auditwheel --no-cache-dir

for f in artifacts/*.whl; do
auditwheel repair --plat manylinux2010_x86_64 $f
auditwheel repair --plat manylinux2014_x86_64 $f
done

ls -al wheelhouse/
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ jobs:
if: github.ref != 'refs/heads/main'
shell: bash
- name: Install pip dependencies
run: pip install tensorflow-cpu~=2.8.0 larq~=0.11 larq_zoo~=2.0 pytest tensorflow_datasets~=4.4 flatbuffers==1.12 tqdm --no-cache-dir
run: pip install tensorflow-cpu~=2.9.0 larq~=0.11 larq_zoo~=2.0 pytest tensorflow_datasets~=4.4 flatbuffers==1.12 tqdm --no-cache-dir
- name: Run Interpreter test
run: bazelisk test larq_compute_engine/tflite/tests:interpreter_test --test_output=all
- name: Run FileCheck tests
Expand All @@ -101,7 +101,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
tf-version: [1.14.0, 1.15.5, 2.0.4, 2.1.4, 2.2.3, 2.3.3, 2.4.4, 2.5.3, 2.6.3, 2.7.1, 2.8.0]
tf-version: [1.14.0, 1.15.5, 2.0.4, 2.1.4, 2.2.3, 2.3.3, 2.4.4, 2.5.3, 2.6.4, 2.7.2, 2.8.1, 2.9.0]
if: "!contains(github.event.head_commit.message, 'ci-skip')"
steps:
- uses: actions/checkout@v3
Expand Down
24 changes: 4 additions & 20 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,10 @@ if(COMPILE_BENCHMARK)
${LCE_SOURCE_DIR}/tflite/benchmark/lce_benchmark_tflite_model.h
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_model.h
)
set(TFLITE_BENCHMARK_SRCS # from ${TFLITE_SOURCE_DIR}/tools/benchmark/CMakeLists.txt
${TENSORFLOW_SOURCE_DIR}/tensorflow/core/util/stats_calculator.cc
${TFLITE_SOURCE_DIR}/kernels/internal/utils/sparsity_format_converter.cc
${TFLITE_SOURCE_DIR}/profiling/memory_info.cc
${TFLITE_SOURCE_DIR}/profiling/memory_usage_monitor.cc
${TFLITE_SOURCE_DIR}/profiling/profile_summarizer.cc
${TFLITE_SOURCE_DIR}/profiling/profile_summary_formatter.cc
${TFLITE_SOURCE_DIR}/profiling/time.cc
${TFLITE_SOURCE_DIR}/tools/command_line_flags.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_model.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_performance_options.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_tflite_model.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_utils.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/profiling_listener.cc
${TFLITE_SOURCE_DIR}/tools/delegates/default_execution_provider.cc
${TFLITE_SOURCE_DIR}/tools/delegates/delegate_provider.cc
${TFLITE_SOURCE_DIR}/tools/delegates/xnnpack_delegate_provider.cc
${TFLITE_SOURCE_DIR}/tools/evaluation/utils.cc
${TFLITE_SOURCE_DIR}/tools/tool_params.cc
)

get_directory_property(TFLITE_BENCHMARK_SRCS DIRECTORY ${TFLITE_SOURCE_DIR}/tools/benchmark DEFINITION TFLITE_BENCHMARK_SRCS)
list(FILTER TFLITE_BENCHMARK_SRCS EXCLUDE REGEX benchmark_main.cc)

add_executable(lce_benchmark_model
${TFLITE_BENCHMARK_SRCS}
${LCE_CORE_SRCS} ${LCE_CORE_HDRS}
Expand Down
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ RUN pip install six numpy --no-cache-dir
WORKDIR /compute-engine
COPY . .
RUN ./third_party/install_android.sh
ENV MANYLINUX2010=1
RUN ./configure.py
RUN bazelisk --version
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ http_archive(
"//third_party/tensorflow_patches:disable_forced_mkl.patch",
"//third_party/tensorflow_patches:fix_armhf_xnnpack.patch",
],
sha256 = "66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7",
strip_prefix = "tensorflow-2.8.0",
sha256 = "8087cb0c529f04a4bfe480e49925cd64a904ad16d8ec66b98e2aacdfd53c80ff",
strip_prefix = "tensorflow-2.9.0",
urls = [
"https://github.com/tensorflow/tensorflow/archive/v2.8.0.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/v2.9.0.tar.gz",
],
)

Expand Down
16 changes: 8 additions & 8 deletions larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ td_library(
srcs = ["transforms/op_removal_patterns.td"],
includes = ["/external/org_tensorflow"],
deps = [
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
Expand All @@ -43,7 +43,7 @@ td_library(
includes = ["/external/org_tensorflow"],
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)
Expand All @@ -54,7 +54,7 @@ td_library(
includes = ["/external/org_tensorflow"],
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
Expand Down Expand Up @@ -182,7 +182,7 @@ gentbl_cc_library(
td_file = "transforms/bitpack_activations_patterns.td",
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)
Expand All @@ -199,7 +199,7 @@ gentbl_cc_library(
td_file = "transforms/bitpack_weights_patterns.td",
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
Expand Down Expand Up @@ -288,7 +288,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
],
alwayslink = 1,
Expand All @@ -308,7 +308,7 @@ cc_library(
deps = [
":larq_compute_engine",
"//larq_compute_engine/core:types",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
"@org_tensorflow//tensorflow/compiler/mlir/lite:validators",
Expand Down Expand Up @@ -429,7 +429,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
],
alwayslink = 1,
)
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/ir/lce_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def LarqDialect : Dialect {
//===----------------------------------------------------------------------===//

// Base class for the operation in this dialect
class LQ_Op<string mnemonic, list<OpTrait> traits = []> :
class LQ_Op<string mnemonic, list<Trait> traits = []> :
Op<LarqDialect, mnemonic, traits> {

let extraClassDeclaration = [{
Expand Down
11 changes: 5 additions & 6 deletions larq_compute_engine/mlir/lce_mlir_opt.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

int main(int argc, char** argv) {
mlir::registerTransformsPasses();
mlir::DialectRegistry registry;
registry.insert<mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
mlir::quant::QuantizationDialect, mlir::TF::TensorFlowDialect,
mlir::TFL::TensorFlowLiteDialect, mlir::lq::LarqDialect>();
return failed(mlir::MlirOptMain(argc, argv,
"Larq Compute Engine pass driver\n", registry,
/*preloadDialectsInContext=*/false));
return failed(mlir::MlirOptMain(
argc, argv, "Larq Compute Engine pass driver\n", registry));
}
11 changes: 6 additions & 5 deletions larq_compute_engine/mlir/python/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ LCETarget GetLCETarget(const std::string& target_str) {
}
}

Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module,
int* num_inputs) {
*num_inputs = 0;
mlir::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::FuncOp>()) {
mlir::func::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::func::FuncOp>()) {
if (auto tf_attrs =
func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
// TODO(jaesung): There could be multiple entry functions. Let's handle
Expand Down Expand Up @@ -70,13 +71,13 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
}

pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant) {
mlir::TFL::QuantizationSpecs quant_specs;
mlir::quant::QuantizationSpecs quant_specs;
if (should_quantize) {
// Normally we'd only set `inference_type` to QINT8 when there are
// fake_quant nodes in the graph. However this did not work reliably, and
Expand Down
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace tensorflow {

LCETarget GetLCETarget(const std::string& target_str);

Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs);
Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module, int* num_inputs);

pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/bitpack-weights.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -tfl-lce-bitpack-weights -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @bitpack_bconv2d_filters
func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
func.func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
%cst = arith.constant dense<1.0> : tensor<16x3x3x3xf32>
%0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
Expand Down
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/tests/const-fold.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: @quantize
func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
func.func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
%pos = arith.constant dense< 0.5> : tensor<1x1x2x32xf32>
%neg = arith.constant dense<-0.5> : tensor<1x1x2x32xf32>
%0 = "lq.Quantize"(%pos) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32>
Expand All @@ -14,7 +14,7 @@ func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
}

// CHECK-LABEL: @dequantize
func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
func.func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
%pos = arith.constant dense<0> : tensor<1x1x2x1xi32>
%neg = arith.constant dense<-1> : tensor<1x1x2x1xi32>
%0 = "lq.Dequantize"(%pos) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
Expand Down
12 changes: 6 additions & 6 deletions larq_compute_engine/mlir/tests/fuse_padding.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -tfl-fuse-padding -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @fuse_pad_into_conv_valid
func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
func.func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
Expand All @@ -14,7 +14,7 @@ func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x1
}

// CHECK-LABEL: @fuse_padv2_into_conv_valid
func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
func.func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<0.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
Expand All @@ -28,7 +28,7 @@ func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64
}

// CHECK-LABEL: @fuse_pad_into_dwconv_valid
func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> {
func.func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
Expand All @@ -41,7 +41,7 @@ func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x6
}

// CHECK-LABEL: @do_not_fuse_padv2_into_conv_wrong_pad_value
func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
func.func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
Expand All @@ -54,7 +54,7 @@ func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>)
}

// CHECK-LABEL: @do_not_fuse_pad_into_conv_same
func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> {
func.func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
Expand All @@ -67,7 +67,7 @@ func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x6
}

// CHECK-LABEL: @do_not_fuse_pad_into_dwconv_channelpad
func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> {
func.func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [1, 3]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
Expand Down
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/tests/legalize-lce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: lce-tf-opt %s -tfl-legalize-lce -lce-translate-tfl -verify-diagnostics | FileCheck %s --check-prefix=TRANSLATE

// CHECK-LABEL: @legalize_bconv2d
func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
func.func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
%0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>

Expand All @@ -14,7 +14,7 @@ func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf3
}

// CHECK-LABEL: @legalize_bmax_pool2d
func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> {
func.func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> {
%0 = "lq.BMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32>
return %0 : tensor<256x16x16x3xi32>

Expand All @@ -26,7 +26,7 @@ func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3
}

// CHECK-LABEL: @legalize_quantize
func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> {
func.func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> {
%0 = "lq.Quantize"(%arg0) {} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32>
return %0 : tensor<256x32x32x2xi32>

Expand All @@ -38,7 +38,7 @@ func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi
}

// CHECK-LABEL: @legalize_dequantize
func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> {
func.func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> {
%0 = "lq.Dequantize"(%arg0) {} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32>
return %0 : tensor<256x32x32x64xf32>

Expand Down
Loading