Skip to content

Commit

Permalink
TF_USE_LEGACY_KERAS in bazel
Browse files Browse the repository at this point in the history
  • Loading branch information
Tombana committed Jun 16, 2024
1 parent 7d9a935 commit 911bd93
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
- name: Run FileCheck tests
run: bazelisk test larq_compute_engine/mlir/tests:all --test_output=all
- name: Run End2End tests
run: bazelisk test larq_compute_engine/tests:end2end_test --test_output=all
run: bazelisk test larq_compute_engine/tests:end2end_test --test_output=all --test_env=TF_USE_LEGACY_KERAS=1
- name: Run Strip dequantize op tests
run: bazelisk test larq_compute_engine/tests:strip_lcedequantize_test --test_output=all

Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ py_library(
],
deps = [
":_tf_tfl_flatbuffer",
lce_requirement("tensorflow-cpu"),
lce_requirement("tensorflow"),
lce_requirement("flatbuffers"),
],
)
Expand Down
3 changes: 2 additions & 1 deletion larq_compute_engine/requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
tensorflow-cpu
tensorflow==2.16.1
tf-keras==2.16.0
tensorflow-datasets
larq
tqdm
Expand Down
52 changes: 28 additions & 24 deletions larq_compute_engine/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ absl-py==2.1.0
# etils
# keras
# tensorboard
# tensorflow-cpu
# tensorflow
# tensorflow-datasets
# tensorflow-metadata
array-record==0.5.1
# via tensorflow-datasets
astunparse==1.6.3
# via tensorflow-cpu
# via tensorflow
certifi==2024.6.2
# via requests
charset-normalizer==3.3.2
Expand All @@ -34,23 +34,23 @@ etils==1.7.0
exceptiongroup==1.2.1
# via pytest
flatbuffers==24.3.25
# via tensorflow-cpu
# via tensorflow
fsspec==2024.6.0
# via etils
gast==0.5.4
# via tensorflow-cpu
# via tensorflow
google-pasta==0.2.0
# via tensorflow-cpu
# via tensorflow
googleapis-common-protos==1.63.1
# via -r larq_compute_engine/requirements.in
grpcio==1.64.1
# via
# tensorboard
# tensorflow-cpu
# tensorflow
h5py==3.11.0
# via
# keras
# tensorflow-cpu
# tensorflow
idna==3.7
# via requests
immutabledict==4.2.0
Expand All @@ -60,11 +60,11 @@ importlib-resources==6.4.0
iniconfig==2.0.0
# via pytest
keras==3.3.3
# via tensorflow-cpu
# via tensorflow
larq==0.13.3
# via -r larq_compute_engine/requirements.in
libclang==18.1.1
# via tensorflow-cpu
# via tensorflow
markdown==3.6
# via tensorboard
markdown-it-py==3.0.0
Expand All @@ -76,7 +76,7 @@ mdurl==0.1.2
ml-dtypes==0.3.2
# via
# keras
# tensorflow-cpu
# tensorflow
namex==0.0.8
# via keras
numpy==1.26.4
Expand All @@ -89,17 +89,17 @@ numpy==1.26.4
# opt-einsum
# pyarrow
# tensorboard
# tensorflow-cpu
# tensorflow
# tensorflow-datasets
opt-einsum==3.3.0
# via tensorflow-cpu
# via tensorflow
optree==0.11.0
# via keras
packaging==24.1
# via
# larq
# pytest
# tensorflow-cpu
# tensorflow
pluggy==1.5.0
# via pytest
promise==2.3
Expand All @@ -108,7 +108,7 @@ protobuf==3.20.3
# via
# googleapis-common-protos
# tensorboard
# tensorflow-cpu
# tensorflow
# tensorflow-datasets
# tensorflow-metadata
psutil==5.9.8
Expand All @@ -121,7 +121,7 @@ pytest==8.2.2
# via -r larq_compute_engine/requirements.in
requests==2.32.3
# via
# tensorflow-cpu
# tensorflow
# tensorflow-datasets
rich==13.7.1
# via keras
Expand All @@ -133,25 +133,29 @@ six==1.16.0
# google-pasta
# promise
# tensorboard
# tensorflow-cpu
# tensorflow
tensorboard==2.16.2
# via tensorflow-cpu
# via tensorflow
tensorboard-data-server==0.7.2
# via tensorboard
tensorflow-cpu==2.16.1
# via -r larq_compute_engine/requirements.in
tensorflow==2.16.1
# via
# -r larq_compute_engine/requirements.in
# tf-keras
tensorflow-datasets==4.9.6
# via -r larq_compute_engine/requirements.in
tensorflow-io-gcs-filesystem==0.37.0
# via tensorflow-cpu
# via tensorflow
tensorflow-metadata==1.15.0
# via tensorflow-datasets
termcolor==2.4.0
# via
# tensorflow-cpu
# tensorflow
# tensorflow-datasets
terminaltables==3.1.10
# via larq
tf-keras==2.16.0
# via -r larq_compute_engine/requirements.in
toml==0.10.2
# via tensorflow-datasets
tomli==2.0.1
Expand All @@ -166,7 +170,7 @@ typing-extensions==4.12.2
# etils
# optree
# simple-parsing
# tensorflow-cpu
# tensorflow
urllib3==2.2.1
# via requests
werkzeug==3.0.3
Expand All @@ -175,7 +179,7 @@ wheel==0.43.0
# via astunparse
wrapt==1.16.0
# via
# tensorflow-cpu
# tensorflow
# tensorflow-datasets
zipp==3.19.2
# via etils
Expand All @@ -184,4 +188,4 @@ zipp==3.19.2
setuptools==70.0.0
# via
# tensorboard
# tensorflow-cpu
# tensorflow
3 changes: 2 additions & 1 deletion larq_compute_engine/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ py_test(
tf_requirement("numpy"),
lce_requirement("larq"),
lce_requirement("pytest"),
lce_requirement("tensorflow-cpu"),
lce_requirement("tensorflow"),
lce_requirement("tensorflow_datasets"),
lce_requirement("tf-keras"),
],
)

Expand Down
58 changes: 29 additions & 29 deletions larq_compute_engine/tflite/tests/BUILD
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
load("@pypi//:requirements.bzl", tf_requirement = "requirement")
load("@pypi_lce//:requirements.bzl", lce_requirement = "requirement")
load("@pypi//:requirements.bzl", tf_requirement="requirement")
load("@pypi_lce//:requirements.bzl", lce_requirement="requirement")

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
default_visibility=["//visibility:public"],
licenses=["notice"], # Apache 2.0
)

cc_library(
name = "utils",
hdrs = [
name="utils",
hdrs=[
"utils.h",
],
deps = [
deps=[
"//larq_compute_engine/core:types",
"@org_tensorflow//tensorflow/lite/kernels/internal:types",
],
)

cc_library(
name = "bconv2d_op_model",
hdrs = [
name="bconv2d_op_model",
hdrs=[
"bconv2d_op_model.h",
],
deps = [
deps=[
":utils",
"//larq_compute_engine/tflite/kernels:lce_op_kernels",
"@flatbuffers",
],
)

cc_test(
name = "bconv2d_test",
size = "large",
srcs = ["bconv2d_test.cc"],
deps = [
name="bconv2d_test",
size="large",
srcs=["bconv2d_test.cc"],
deps=[
":bconv2d_op_model",
":utils",
"//larq_compute_engine/core/bitpacking:bitpack",
Expand All @@ -48,10 +48,10 @@ cc_test(
)

cc_test(
name = "bmaxpool_test",
size = "small",
srcs = ["bmaxpool_test.cc"],
deps = [
name="bmaxpool_test",
size="small",
srcs=["bmaxpool_test.cc"],
deps=[
":utils",
"//larq_compute_engine/core/bitpacking:utils",
"//larq_compute_engine/tflite/kernels:lce_op_kernels",
Expand All @@ -64,10 +64,10 @@ cc_test(
)

cc_test(
name = "quantization_test",
size = "small",
srcs = ["quantization_test.cc"],
deps = [
name="quantization_test",
size="small",
srcs=["quantization_test.cc"],
deps=[
":utils",
"//larq_compute_engine/tflite/kernels:lce_op_kernels",
"@com_google_googletest//:gtest",
Expand All @@ -79,22 +79,22 @@ cc_test(
)

py_test(
name = "interpreter_test",
size = "small",
srcs = ["interpreter_test.py"],
deps = [
name="interpreter_test",
size="small",
srcs=["interpreter_test.py"],
deps=[
"//larq_compute_engine/tflite/python:interpreter",
tf_requirement("numpy"),
lce_requirement("pytest"),
lce_requirement("tensorflow-cpu"),
lce_requirement("tensorflow"),
],
)

# COLLECTION OF ALL TFLITE CC TESTS
# each new cc test needs to be added here
test_suite(
name = "cc_tests",
tests = [
name="cc_tests",
tests=[
":bconv2d_test",
":bmaxpool_test",
":quantization_test",
Expand Down

0 comments on commit 911bd93

Please sign in to comment.